mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-11-04 10:01:05 +01:00 
			
		
		
		
	Make TLS setup work automatically
This commit injects the per-test-generated tls certs into the tailscale container and makes sure all can ping all. It does not test any of the DERP isolation yet. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									89ff5c83d2
								
							
						
					
					
						commit
						9bc6ac0f35
					
				@ -1,52 +1,83 @@
 | 
				
			|||||||
package hsic
 | 
					package hsic
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"archive/tar"
 | 
					 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
 | 
						"crypto/rand"
 | 
				
			||||||
 | 
						"crypto/rsa"
 | 
				
			||||||
 | 
						"crypto/tls"
 | 
				
			||||||
 | 
						"crypto/x509"
 | 
				
			||||||
 | 
						"crypto/x509/pkix"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"encoding/pem"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"io"
 | 
					 | 
				
			||||||
	"log"
 | 
						"log"
 | 
				
			||||||
 | 
						"math/big"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"path/filepath"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/juanfont/headscale"
 | 
						"github.com/juanfont/headscale"
 | 
				
			||||||
	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 | 
						v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 | 
				
			||||||
	"github.com/juanfont/headscale/integration/dockertestutil"
 | 
						"github.com/juanfont/headscale/integration/dockertestutil"
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/integration/integrationutil"
 | 
				
			||||||
	"github.com/ory/dockertest/v3"
 | 
						"github.com/ory/dockertest/v3"
 | 
				
			||||||
	"github.com/ory/dockertest/v3/docker"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	hsicHashLength    = 6
 | 
						hsicHashLength       = 6
 | 
				
			||||||
	dockerContextPath = "../."
 | 
						dockerContextPath    = "../."
 | 
				
			||||||
	aclPolicyPath     = "/etc/headscale/acl.hujson"
 | 
						aclPolicyPath        = "/etc/headscale/acl.hujson"
 | 
				
			||||||
 | 
						tlsCertPath          = "/etc/headscale/tls.cert"
 | 
				
			||||||
 | 
						tlsKeyPath           = "/etc/headscale/tls.key"
 | 
				
			||||||
 | 
						headscaleDefaultPort = 8080
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var errHeadscaleStatusCodeNotOk = errors.New("headscale status code not ok")
 | 
					var errHeadscaleStatusCodeNotOk = errors.New("headscale status code not ok")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type HeadscaleInContainer struct {
 | 
					type HeadscaleInContainer struct {
 | 
				
			||||||
	hostname string
 | 
						hostname string
 | 
				
			||||||
	port     int
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	pool      *dockertest.Pool
 | 
						pool      *dockertest.Pool
 | 
				
			||||||
	container *dockertest.Resource
 | 
						container *dockertest.Resource
 | 
				
			||||||
	network   *dockertest.Network
 | 
						network   *dockertest.Network
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// optional config
 | 
						// optional config
 | 
				
			||||||
 | 
						port      int
 | 
				
			||||||
	aclPolicy *headscale.ACLPolicy
 | 
						aclPolicy *headscale.ACLPolicy
 | 
				
			||||||
	env       []string
 | 
						env       []string
 | 
				
			||||||
 | 
						tlsCert   []byte
 | 
				
			||||||
 | 
						tlsKey    []byte
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Option = func(c *HeadscaleInContainer)
 | 
					type Option = func(c *HeadscaleInContainer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func WithACLPolicy(acl *headscale.ACLPolicy) Option {
 | 
					func WithACLPolicy(acl *headscale.ACLPolicy) Option {
 | 
				
			||||||
	return func(hsic *HeadscaleInContainer) {
 | 
						return func(hsic *HeadscaleInContainer) {
 | 
				
			||||||
 | 
							// TODO(kradalby): Move somewhere appropriate
 | 
				
			||||||
 | 
							hsic.env = append(hsic.env, fmt.Sprintf("HEADSCALE_ACL_POLICY_PATH=%s", aclPolicyPath))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		hsic.aclPolicy = acl
 | 
							hsic.aclPolicy = acl
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func WithTLS() Option {
 | 
				
			||||||
 | 
						return func(hsic *HeadscaleInContainer) {
 | 
				
			||||||
 | 
							cert, key, err := createCertificate()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Fatalf("failed to create certificates for headscale test: %s", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// TODO(kradalby): Move somewhere appropriate
 | 
				
			||||||
 | 
							hsic.env = append(hsic.env, fmt.Sprintf("HEADSCALE_TLS_CERT_PATH=%s", tlsCertPath))
 | 
				
			||||||
 | 
							hsic.env = append(hsic.env, fmt.Sprintf("HEADSCALE_TLS_KEY_PATH=%s", tlsKeyPath))
 | 
				
			||||||
 | 
							hsic.env = append(hsic.env, "HEADSCALE_TLS_CLIENT_AUTH_MODE=disabled")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							hsic.tlsCert = cert
 | 
				
			||||||
 | 
							hsic.tlsKey = key
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func WithConfigEnv(configEnv map[string]string) Option {
 | 
					func WithConfigEnv(configEnv map[string]string) Option {
 | 
				
			||||||
	return func(hsic *HeadscaleInContainer) {
 | 
						return func(hsic *HeadscaleInContainer) {
 | 
				
			||||||
		env := []string{}
 | 
							env := []string{}
 | 
				
			||||||
@ -59,9 +90,14 @@ func WithConfigEnv(configEnv map[string]string) Option {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func WithPort(port int) Option {
 | 
				
			||||||
 | 
						return func(hsic *HeadscaleInContainer) {
 | 
				
			||||||
 | 
							hsic.port = port
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func New(
 | 
					func New(
 | 
				
			||||||
	pool *dockertest.Pool,
 | 
						pool *dockertest.Pool,
 | 
				
			||||||
	port int,
 | 
					 | 
				
			||||||
	network *dockertest.Network,
 | 
						network *dockertest.Network,
 | 
				
			||||||
	opts ...Option,
 | 
						opts ...Option,
 | 
				
			||||||
) (*HeadscaleInContainer, error) {
 | 
					) (*HeadscaleInContainer, error) {
 | 
				
			||||||
@ -71,11 +107,10 @@ func New(
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	hostname := fmt.Sprintf("hs-%s", hash)
 | 
						hostname := fmt.Sprintf("hs-%s", hash)
 | 
				
			||||||
	portProto := fmt.Sprintf("%d/tcp", port)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	hsic := &HeadscaleInContainer{
 | 
						hsic := &HeadscaleInContainer{
 | 
				
			||||||
		hostname: hostname,
 | 
							hostname: hostname,
 | 
				
			||||||
		port:     port,
 | 
							port:     headscaleDefaultPort,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		pool:    pool,
 | 
							pool:    pool,
 | 
				
			||||||
		network: network,
 | 
							network: network,
 | 
				
			||||||
@ -85,9 +120,7 @@ func New(
 | 
				
			|||||||
		opt(hsic)
 | 
							opt(hsic)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if hsic.aclPolicy != nil {
 | 
						portProto := fmt.Sprintf("%d/tcp", hsic.port)
 | 
				
			||||||
		hsic.env = append(hsic.env, fmt.Sprintf("HEADSCALE_ACL_POLICY_PATH=%s", aclPolicyPath))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	headscaleBuildOptions := &dockertest.BuildOptions{
 | 
						headscaleBuildOptions := &dockertest.BuildOptions{
 | 
				
			||||||
		Dockerfile: "Dockerfile.debug",
 | 
							Dockerfile: "Dockerfile.debug",
 | 
				
			||||||
@ -144,9 +177,25 @@ func New(
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if hsic.hasTLS() {
 | 
				
			||||||
 | 
							err = hsic.WriteFile(tlsCertPath, hsic.tlsCert)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							err = hsic.WriteFile(tlsKeyPath, hsic.tlsKey)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, fmt.Errorf("failed to write TLS key to container: %w", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return hsic, nil
 | 
						return hsic, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (t *HeadscaleInContainer) hasTLS() bool {
 | 
				
			||||||
 | 
						return len(t.tlsCert) != 0 && len(t.tlsKey) != 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *HeadscaleInContainer) Shutdown() error {
 | 
					func (t *HeadscaleInContainer) Shutdown() error {
 | 
				
			||||||
	return t.pool.Purge(t.container)
 | 
						return t.pool.Purge(t.container)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -183,11 +232,7 @@ func (t *HeadscaleInContainer) GetPort() string {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *HeadscaleInContainer) GetHealthEndpoint() string {
 | 
					func (t *HeadscaleInContainer) GetHealthEndpoint() string {
 | 
				
			||||||
	hostEndpoint := fmt.Sprintf("%s:%d",
 | 
						return fmt.Sprintf("%s/health", t.GetEndpoint())
 | 
				
			||||||
		t.GetIP(),
 | 
					 | 
				
			||||||
		t.port)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return fmt.Sprintf("http://%s/health", hostEndpoint)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *HeadscaleInContainer) GetEndpoint() string {
 | 
					func (t *HeadscaleInContainer) GetEndpoint() string {
 | 
				
			||||||
@ -195,17 +240,39 @@ func (t *HeadscaleInContainer) GetEndpoint() string {
 | 
				
			|||||||
		t.GetIP(),
 | 
							t.GetIP(),
 | 
				
			||||||
		t.port)
 | 
							t.port)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if t.hasTLS() {
 | 
				
			||||||
 | 
							return fmt.Sprintf("https://%s", hostEndpoint)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return fmt.Sprintf("http://%s", hostEndpoint)
 | 
						return fmt.Sprintf("http://%s", hostEndpoint)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (t *HeadscaleInContainer) GetCert() []byte {
 | 
				
			||||||
 | 
						return t.tlsCert
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (t *HeadscaleInContainer) GetHostname() string {
 | 
				
			||||||
 | 
						return t.hostname
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *HeadscaleInContainer) WaitForReady() error {
 | 
					func (t *HeadscaleInContainer) WaitForReady() error {
 | 
				
			||||||
	url := t.GetHealthEndpoint()
 | 
						url := t.GetHealthEndpoint()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Printf("waiting for headscale to be ready at %s", url)
 | 
						log.Printf("waiting for headscale to be ready at %s", url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						client := &http.Client{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if t.hasTLS() {
 | 
				
			||||||
 | 
							insecureTransport := http.DefaultTransport.(*http.Transport).Clone()
 | 
				
			||||||
 | 
							insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
 | 
				
			||||||
 | 
							client = &http.Client{Transport: insecureTransport}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return t.pool.Retry(func() error {
 | 
						return t.pool.Retry(func() error {
 | 
				
			||||||
		resp, err := http.Get(url) //nolint
 | 
							resp, err := client.Get(url) //nolint
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Printf("ready err: %s", err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			return fmt.Errorf("headscale is not ready: %w", err)
 | 
								return fmt.Errorf("headscale is not ready: %w", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -292,55 +359,96 @@ func (t *HeadscaleInContainer) ListMachinesInNamespace(
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *HeadscaleInContainer) WriteFile(path string, data []byte) error {
 | 
					func (t *HeadscaleInContainer) WriteFile(path string, data []byte) error {
 | 
				
			||||||
	dirPath, fileName := filepath.Split(path)
 | 
						return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	file := bytes.NewReader(data)
 | 
					func createCertificate() ([]byte, []byte, error) {
 | 
				
			||||||
 | 
						// From:
 | 
				
			||||||
 | 
						// https://shaneutt.com/blog/golang-ca-and-signed-cert-go/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	buf := bytes.NewBuffer([]byte{})
 | 
						ca := &x509.Certificate{
 | 
				
			||||||
 | 
							SerialNumber: big.NewInt(2019),
 | 
				
			||||||
	tarWriter := tar.NewWriter(buf)
 | 
							Subject: pkix.Name{
 | 
				
			||||||
 | 
								Organization: []string{"Headscale testing INC"},
 | 
				
			||||||
	header := &tar.Header{
 | 
								Country:      []string{"NL"},
 | 
				
			||||||
		Name: fileName,
 | 
								Locality:     []string{"Leiden"},
 | 
				
			||||||
		Size: file.Size(),
 | 
					 | 
				
			||||||
		// Mode:    int64(stat.Mode()),
 | 
					 | 
				
			||||||
		// ModTime: stat.ModTime(),
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err := tarWriter.WriteHeader(header)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return fmt.Errorf("failed write file header to tar: %w", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_, err = io.Copy(tarWriter, file)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return fmt.Errorf("failed to copy file to tar: %w", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err = tarWriter.Close()
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return fmt.Errorf("failed to close tar: %w", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	log.Printf("tar: %s", buf.String())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Ensure the directory is present inside the container
 | 
					 | 
				
			||||||
	_, err = t.Execute([]string{"mkdir", "-p", dirPath})
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return fmt.Errorf("failed to ensure directory: %w", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err = t.pool.Client.UploadToContainer(
 | 
					 | 
				
			||||||
		t.container.Container.ID,
 | 
					 | 
				
			||||||
		docker.UploadToContainerOptions{
 | 
					 | 
				
			||||||
			NoOverwriteDirNonDir: false,
 | 
					 | 
				
			||||||
			Path:                 dirPath,
 | 
					 | 
				
			||||||
			InputStream:          bytes.NewReader(buf.Bytes()),
 | 
					 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
							NotBefore: time.Now(),
 | 
				
			||||||
 | 
							NotAfter:  time.Now().Add(30 * time.Minute),
 | 
				
			||||||
 | 
							IsCA:      true,
 | 
				
			||||||
 | 
							ExtKeyUsage: []x509.ExtKeyUsage{
 | 
				
			||||||
 | 
								x509.ExtKeyUsageClientAuth,
 | 
				
			||||||
 | 
								x509.ExtKeyUsageServerAuth,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
 | 
				
			||||||
 | 
							BasicConstraintsValid: true,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
 | 
				
			||||||
 | 
						// if err != nil {
 | 
				
			||||||
 | 
						// 	return nil, err
 | 
				
			||||||
 | 
						// }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cert := &x509.Certificate{
 | 
				
			||||||
 | 
							SerialNumber: big.NewInt(1658),
 | 
				
			||||||
 | 
							Subject: pkix.Name{
 | 
				
			||||||
 | 
								Organization: []string{"Headscale testing INC"},
 | 
				
			||||||
 | 
								Country:      []string{"NL"},
 | 
				
			||||||
 | 
								Locality:     []string{"Leiden"},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							IPAddresses:  []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
 | 
				
			||||||
 | 
							NotBefore:    time.Now(),
 | 
				
			||||||
 | 
							NotAfter:     time.Now().Add(30 * time.Minute),
 | 
				
			||||||
 | 
							SubjectKeyId: []byte{1, 2, 3, 4, 6},
 | 
				
			||||||
 | 
							ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
 | 
				
			||||||
 | 
							KeyUsage:     x509.KeyUsageDigitalSignature,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						certBytes, err := x509.CreateCertificate(
 | 
				
			||||||
 | 
							rand.Reader,
 | 
				
			||||||
 | 
							cert,
 | 
				
			||||||
 | 
							ca,
 | 
				
			||||||
 | 
							&certPrivKey.PublicKey,
 | 
				
			||||||
 | 
							caPrivKey,
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return nil, nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						certPEM := new(bytes.Buffer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = pem.Encode(certPEM, &pem.Block{
 | 
				
			||||||
 | 
							Type:  "CERTIFICATE",
 | 
				
			||||||
 | 
							Bytes: certBytes,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						certPrivKeyPEM := new(bytes.Buffer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = pem.Encode(certPrivKeyPEM, &pem.Block{
 | 
				
			||||||
 | 
							Type:  "RSA PRIVATE KEY",
 | 
				
			||||||
 | 
							Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// serverCert, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes())
 | 
				
			||||||
 | 
						// if err != nil {
 | 
				
			||||||
 | 
						// 	return nil, err
 | 
				
			||||||
 | 
						// }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										77
									
								
								integration/integrationutil/util.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								integration/integrationutil/util.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,77 @@
 | 
				
			|||||||
 | 
					package integrationutil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"archive/tar"
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"log"
 | 
				
			||||||
 | 
						"path/filepath"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/integration/dockertestutil"
 | 
				
			||||||
 | 
						"github.com/ory/dockertest/v3"
 | 
				
			||||||
 | 
						"github.com/ory/dockertest/v3/docker"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func WriteFileToContainer(
 | 
				
			||||||
 | 
						pool *dockertest.Pool,
 | 
				
			||||||
 | 
						container *dockertest.Resource,
 | 
				
			||||||
 | 
						path string,
 | 
				
			||||||
 | 
						data []byte,
 | 
				
			||||||
 | 
					) error {
 | 
				
			||||||
 | 
						dirPath, fileName := filepath.Split(path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						file := bytes.NewReader(data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						buf := bytes.NewBuffer([]byte{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tarWriter := tar.NewWriter(buf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						header := &tar.Header{
 | 
				
			||||||
 | 
							Name: fileName,
 | 
				
			||||||
 | 
							Size: file.Size(),
 | 
				
			||||||
 | 
							// Mode:    int64(stat.Mode()),
 | 
				
			||||||
 | 
							// ModTime: stat.ModTime(),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err := tarWriter.WriteHeader(header)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("failed write file header to tar: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err = io.Copy(tarWriter, file)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("failed to copy file to tar: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = tarWriter.Close()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("failed to close tar: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.Printf("tar: %s", buf.String())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Ensure the directory is present inside the container
 | 
				
			||||||
 | 
						_, _, err = dockertestutil.ExecuteCommand(
 | 
				
			||||||
 | 
							container,
 | 
				
			||||||
 | 
							[]string{"mkdir", "-p", dirPath},
 | 
				
			||||||
 | 
							[]string{},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("failed to ensure directory: %w", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = pool.Client.UploadToContainer(
 | 
				
			||||||
 | 
							container.Container.ID,
 | 
				
			||||||
 | 
							docker.UploadToContainerOptions{
 | 
				
			||||||
 | 
								NoOverwriteDirNonDir: false,
 | 
				
			||||||
 | 
								Path:                 dirPath,
 | 
				
			||||||
 | 
								InputStream:          bytes.NewReader(buf.Bytes()),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -150,20 +150,8 @@ func (s *Scenario) Namespaces() []string {
 | 
				
			|||||||
// Note: These functions assume that there is a _single_ headscale instance for now
 | 
					// Note: These functions assume that there is a _single_ headscale instance for now
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TODO(kradalby): make port and headscale configurable, multiple instances support?
 | 
					// TODO(kradalby): make port and headscale configurable, multiple instances support?
 | 
				
			||||||
func (s *Scenario) StartHeadscale() error {
 | 
					func (s *Scenario) StartHeadscale(opts ...hsic.Option) error {
 | 
				
			||||||
	headscale, err := hsic.New(s.pool, headscalePort, s.network,
 | 
						headscale, err := hsic.New(s.pool, s.network, opts...)
 | 
				
			||||||
		hsic.WithACLPolicy(
 | 
					 | 
				
			||||||
			&headscale.ACLPolicy{
 | 
					 | 
				
			||||||
				ACLs: []headscale.ACL{
 | 
					 | 
				
			||||||
					{
 | 
					 | 
				
			||||||
						Action:       "accept",
 | 
					 | 
				
			||||||
						Sources:      []string{"*"},
 | 
					 | 
				
			||||||
						Destinations: []string{"*:*"},
 | 
					 | 
				
			||||||
					},
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		),
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return fmt.Errorf("failed to create headscale container: %w", err)
 | 
							return fmt.Errorf("failed to create headscale container: %w", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -228,10 +216,22 @@ func (s *Scenario) CreateTailscaleNodesInNamespace(
 | 
				
			|||||||
				defer namespace.createWaitGroup.Done()
 | 
									defer namespace.createWaitGroup.Done()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				// TODO(kradalby): error handle this
 | 
									// TODO(kradalby): error handle this
 | 
				
			||||||
				tsClient, err := tsic.New(s.pool, version, s.network)
 | 
									tsClient, err := tsic.New(
 | 
				
			||||||
 | 
										s.pool,
 | 
				
			||||||
 | 
										version,
 | 
				
			||||||
 | 
										s.network,
 | 
				
			||||||
 | 
										tsic.WithHeadscaleTLS(s.Headscale().GetCert()),
 | 
				
			||||||
 | 
										tsic.WithHeadscaleName(s.Headscale().GetHostname()),
 | 
				
			||||||
 | 
									)
 | 
				
			||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					// return fmt.Errorf("failed to add tailscale node: %w", err)
 | 
										// return fmt.Errorf("failed to add tailscale node: %w", err)
 | 
				
			||||||
					log.Printf("failed to add tailscale node: %s", err)
 | 
										log.Printf("failed to create tailscale node: %s", err)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									err = tsClient.WaitForReady()
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										// return fmt.Errorf("failed to add tailscale node: %w", err)
 | 
				
			||||||
 | 
										log.Printf("failed to wait for tailscaled: %s", err)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				namespace.Clients[tsClient.Hostname()] = tsClient
 | 
									namespace.Clients[tsClient.Hostname()] = tsClient
 | 
				
			||||||
@ -306,8 +306,8 @@ func (s *Scenario) WaitForTailscaleSync() error {
 | 
				
			|||||||
// CreateHeadscaleEnv is a conventient method returning a set up Headcale
 | 
					// CreateHeadscaleEnv is a conventient method returning a set up Headcale
 | 
				
			||||||
// test environment with nodes of all versions, joined to the server with X
 | 
					// test environment with nodes of all versions, joined to the server with X
 | 
				
			||||||
// namespaces.
 | 
					// namespaces.
 | 
				
			||||||
func (s *Scenario) CreateHeadscaleEnv(namespaces map[string]int) error {
 | 
					func (s *Scenario) CreateHeadscaleEnv(namespaces map[string]int, opts ...hsic.Option) error {
 | 
				
			||||||
	err := s.StartHeadscale()
 | 
						err := s.StartHeadscale(opts...)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -12,6 +12,7 @@ import (
 | 
				
			|||||||
	"github.com/cenkalti/backoff/v4"
 | 
						"github.com/cenkalti/backoff/v4"
 | 
				
			||||||
	"github.com/juanfont/headscale"
 | 
						"github.com/juanfont/headscale"
 | 
				
			||||||
	"github.com/juanfont/headscale/integration/dockertestutil"
 | 
						"github.com/juanfont/headscale/integration/dockertestutil"
 | 
				
			||||||
 | 
						"github.com/juanfont/headscale/integration/integrationutil"
 | 
				
			||||||
	"github.com/ory/dockertest/v3"
 | 
						"github.com/ory/dockertest/v3"
 | 
				
			||||||
	"github.com/ory/dockertest/v3/docker"
 | 
						"github.com/ory/dockertest/v3/docker"
 | 
				
			||||||
	"tailscale.com/ipn/ipnstate"
 | 
						"tailscale.com/ipn/ipnstate"
 | 
				
			||||||
@ -20,6 +21,7 @@ import (
 | 
				
			|||||||
const (
 | 
					const (
 | 
				
			||||||
	tsicHashLength    = 6
 | 
						tsicHashLength    = 6
 | 
				
			||||||
	dockerContextPath = "../."
 | 
						dockerContextPath = "../."
 | 
				
			||||||
 | 
						headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
@ -41,12 +43,51 @@ type TailscaleInContainer struct {
 | 
				
			|||||||
	// "cache"
 | 
						// "cache"
 | 
				
			||||||
	ips  []netip.Addr
 | 
						ips  []netip.Addr
 | 
				
			||||||
	fqdn string
 | 
						fqdn string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// optional config
 | 
				
			||||||
 | 
						headscaleCert     []byte
 | 
				
			||||||
 | 
						headscaleHostname string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Option = func(c *TailscaleInContainer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func WithHeadscaleTLS(cert []byte) Option {
 | 
				
			||||||
 | 
						return func(tsic *TailscaleInContainer) {
 | 
				
			||||||
 | 
							tsic.headscaleCert = cert
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func WithOrCreateNetwork(network *dockertest.Network) Option {
 | 
				
			||||||
 | 
						return func(tsic *TailscaleInContainer) {
 | 
				
			||||||
 | 
							if network != nil {
 | 
				
			||||||
 | 
								tsic.network = network
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							network, err := dockertestutil.GetFirstOrCreateNetwork(
 | 
				
			||||||
 | 
								tsic.pool,
 | 
				
			||||||
 | 
								fmt.Sprintf("%s-network", tsic.hostname),
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.Fatalf("failed to create network: %s", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							tsic.network = network
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func WithHeadscaleName(hsName string) Option {
 | 
				
			||||||
 | 
						return func(tsic *TailscaleInContainer) {
 | 
				
			||||||
 | 
							tsic.headscaleHostname = hsName
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func New(
 | 
					func New(
 | 
				
			||||||
	pool *dockertest.Pool,
 | 
						pool *dockertest.Pool,
 | 
				
			||||||
	version string,
 | 
						version string,
 | 
				
			||||||
	network *dockertest.Network,
 | 
						network *dockertest.Network,
 | 
				
			||||||
 | 
						opts ...Option,
 | 
				
			||||||
) (*TailscaleInContainer, error) {
 | 
					) (*TailscaleInContainer, error) {
 | 
				
			||||||
	hash, err := headscale.GenerateRandomStringDNSSafe(tsicHashLength)
 | 
						hash, err := headscale.GenerateRandomStringDNSSafe(tsicHashLength)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@ -55,20 +96,38 @@ func New(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	hostname := fmt.Sprintf("ts-%s-%s", strings.ReplaceAll(version, ".", "-"), hash)
 | 
						hostname := fmt.Sprintf("ts-%s-%s", strings.ReplaceAll(version, ".", "-"), hash)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// TODO(kradalby): figure out why we need to "refresh" the network here.
 | 
						tsic := &TailscaleInContainer{
 | 
				
			||||||
	// network, err = dockertestutil.GetFirstOrCreateNetwork(pool, network.Network.Name)
 | 
							version:  version,
 | 
				
			||||||
	// if err != nil {
 | 
							hostname: hostname,
 | 
				
			||||||
	// 	return nil, err
 | 
					
 | 
				
			||||||
	// }
 | 
							pool:    pool,
 | 
				
			||||||
 | 
							network: network,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, opt := range opts {
 | 
				
			||||||
 | 
							opt(tsic)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tailscaleOptions := &dockertest.RunOptions{
 | 
						tailscaleOptions := &dockertest.RunOptions{
 | 
				
			||||||
		Name:     hostname,
 | 
							Name:     hostname,
 | 
				
			||||||
		Networks: []*dockertest.Network{network},
 | 
							Networks: []*dockertest.Network{network},
 | 
				
			||||||
		Cmd: []string{
 | 
							// Cmd: []string{
 | 
				
			||||||
			"tailscaled", "--tun=tsdev",
 | 
							// 	"tailscaled", "--tun=tsdev",
 | 
				
			||||||
 | 
							// },
 | 
				
			||||||
 | 
							Entrypoint: []string{
 | 
				
			||||||
 | 
								"/bin/bash",
 | 
				
			||||||
 | 
								"-c",
 | 
				
			||||||
 | 
								"/bin/sleep 3 ; update-ca-certificates ; tailscaled --tun=tsdev",
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if tsic.headscaleHostname != "" {
 | 
				
			||||||
 | 
							tailscaleOptions.ExtraHosts = []string{
 | 
				
			||||||
 | 
								"host.docker.internal:host-gateway",
 | 
				
			||||||
 | 
								fmt.Sprintf("%s:host-gateway", tsic.headscaleHostname),
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// dockertest isnt very good at handling containers that has already
 | 
						// dockertest isnt very good at handling containers that has already
 | 
				
			||||||
	// been created, this is an attempt to make sure this container isnt
 | 
						// been created, this is an attempt to make sure this container isnt
 | 
				
			||||||
	// present.
 | 
						// present.
 | 
				
			||||||
@ -89,14 +148,20 @@ func New(
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	log.Printf("Created %s container\n", hostname)
 | 
						log.Printf("Created %s container\n", hostname)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &TailscaleInContainer{
 | 
						tsic.container = container
 | 
				
			||||||
		version:  version,
 | 
					 | 
				
			||||||
		hostname: hostname,
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		pool:      pool,
 | 
						if tsic.hasTLS() {
 | 
				
			||||||
		container: container,
 | 
							err = tsic.WriteFile(headscaleCertPath, tsic.headscaleCert)
 | 
				
			||||||
		network:   network,
 | 
							if err != nil {
 | 
				
			||||||
	}, nil
 | 
								return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return tsic, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (t *TailscaleInContainer) hasTLS() bool {
 | 
				
			||||||
 | 
						return len(t.headscaleCert) != 0
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *TailscaleInContainer) Shutdown() error {
 | 
					func (t *TailscaleInContainer) Shutdown() error {
 | 
				
			||||||
@ -111,6 +176,19 @@ func (t *TailscaleInContainer) Version() string {
 | 
				
			|||||||
	return t.version
 | 
						return t.version
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (t *TailscaleInContainer) WaitForReady() error {
 | 
				
			||||||
 | 
						return t.pool.Retry(func() error {
 | 
				
			||||||
 | 
							// If tailscaled has not started yet, this will return a non-zero
 | 
				
			||||||
 | 
							// status code
 | 
				
			||||||
 | 
							_, err := t.Execute([]string{"tailscale", "status"})
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *TailscaleInContainer) Execute(
 | 
					func (t *TailscaleInContainer) Execute(
 | 
				
			||||||
	command []string,
 | 
						command []string,
 | 
				
			||||||
) (string, string, error) {
 | 
					) (string, string, error) {
 | 
				
			||||||
@ -318,6 +396,10 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string) error {
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (t *TailscaleInContainer) WriteFile(path string, data []byte) error {
 | 
				
			||||||
 | 
						return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func createTailscaleBuildOptions(version string) *dockertest.BuildOptions {
 | 
					func createTailscaleBuildOptions(version string) *dockertest.BuildOptions {
 | 
				
			||||||
	var tailscaleBuildOptions *dockertest.BuildOptions
 | 
						var tailscaleBuildOptions *dockertest.BuildOptions
 | 
				
			||||||
	switch version {
 | 
						switch version {
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user