mirror of
https://github.com/tailscale/tailscale.git
synced 2026-05-05 12:16:44 +02:00
net,wgengine: add support for disco key exchnage via TSMP
Updates tailscale/corp#34037 Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
parent
9eff8a4503
commit
5bfa8e97f6
@ -15,12 +15,10 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"go4.org/mem"
|
||||
"tailscale.com/types/ipproto"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
const minTSMPSize = 7 // the rejected body is 7 bytes
|
||||
const minTSMPSize = 1 // minimum is 1 byte for the type field (e.g., disco key request 'd')
|
||||
|
||||
// TailscaleRejectedHeader is a TSMP message that says that one
|
||||
// Tailscale node has rejected the connection from another. Unlike a
|
||||
@ -75,8 +73,11 @@ const (
|
||||
// TSMPTypePong is the type byte for a TailscalePongResponse.
|
||||
TSMPTypePong TSMPType = 'o'
|
||||
|
||||
// TSPMTypeDiscoAdvertisement is the type byte for sending disco keys
|
||||
TSMPTypeDiscoAdvertisement TSMPType = 'a'
|
||||
// TSMPTypeDiscoKeyRequest is the type byte for a disco key request.
|
||||
TSMPTypeDiscoKeyRequest TSMPType = 'd'
|
||||
|
||||
// TSMPTypeDiscoKeyUpdate is the type byte for a disco key update.
|
||||
TSMPTypeDiscoKeyUpdate TSMPType = 'D'
|
||||
)
|
||||
|
||||
type TailscaleRejectReason byte
|
||||
@ -265,52 +266,62 @@ func (h TSMPPongReply) Marshal(buf []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TSMPDiscoKeyAdvertisement is a TSMP message that's used for distributing Disco Keys.
|
||||
// TSMPDiscoKeyRequest is a TSMP message that requests a peer's disco key.
|
||||
//
|
||||
// On the wire, after the IP header, it's currently 33 bytes:
|
||||
// - 'a' (TSMPTypeDiscoAdvertisement)
|
||||
// - 32 disco key bytes
|
||||
type TSMPDiscoKeyAdvertisement struct {
|
||||
Src, Dst netip.Addr
|
||||
Key key.DiscoPublic
|
||||
}
|
||||
// On the wire, after the IP header, it's currently 1 byte:
|
||||
// - 'd' (TSMPTypeDiscoKeyRequest)
|
||||
type TSMPDiscoKeyRequest struct{}
|
||||
|
||||
func (ka *TSMPDiscoKeyAdvertisement) Marshal() ([]byte, error) {
|
||||
var iph Header
|
||||
if ka.Src.Is4() {
|
||||
iph = IP4Header{
|
||||
IPProto: ipproto.TSMP,
|
||||
Src: ka.Src,
|
||||
Dst: ka.Dst,
|
||||
}
|
||||
} else {
|
||||
iph = IP6Header{
|
||||
IPProto: ipproto.TSMP,
|
||||
Src: ka.Src,
|
||||
Dst: ka.Dst,
|
||||
}
|
||||
}
|
||||
payload := make([]byte, 0, 33)
|
||||
payload = append(payload, byte(TSMPTypeDiscoAdvertisement))
|
||||
payload = ka.Key.AppendTo(payload)
|
||||
if len(payload) != 33 {
|
||||
// Mostly to safeguard against ourselves changing this in the future.
|
||||
return []byte{}, fmt.Errorf("expected payload length 33, got %d", len(payload))
|
||||
}
|
||||
|
||||
return Generate(iph, payload), nil
|
||||
}
|
||||
|
||||
func (pp *Parsed) AsTSMPDiscoAdvertisement() (tka TSMPDiscoKeyAdvertisement, ok bool) {
|
||||
func (pp *Parsed) AsTSMPDiscoKeyRequest() (h TSMPDiscoKeyRequest, ok bool) {
|
||||
if pp.IPProto != ipproto.TSMP {
|
||||
return
|
||||
}
|
||||
p := pp.Payload()
|
||||
if len(p) < 33 || p[0] != byte(TSMPTypeDiscoAdvertisement) {
|
||||
if len(p) < 1 || p[0] != byte(TSMPTypeDiscoKeyRequest) {
|
||||
return
|
||||
}
|
||||
tka.Src = pp.Src.Addr()
|
||||
tka.Key = key.DiscoPublicFromRaw32(mem.B(p[1:33]))
|
||||
|
||||
return tka, true
|
||||
return h, true
|
||||
}
|
||||
|
||||
// TSMPDiscoKeyUpdate is a TSMP message that contains a disco public key.
|
||||
// It may be sent in response to a request, or unsolicited when a node
|
||||
// believes its peer may have stale disco key information.
|
||||
//
|
||||
// On the wire, after the IP header, it's currently 33 bytes:
|
||||
// - 'D' (TSMPTypeDiscoKeyUpdate)
|
||||
// - 32 bytes disco public key
|
||||
type TSMPDiscoKeyUpdate struct {
|
||||
IPHeader Header
|
||||
DiscoKey [32]byte // raw disco public key bytes
|
||||
}
|
||||
|
||||
// AsTSMPDiscoKeyUpdate returns pp as a TSMPDiscoKeyUpdate and whether it is one.
|
||||
// The update.IPHeader field is not populated.
|
||||
func (pp *Parsed) AsTSMPDiscoKeyUpdate() (update TSMPDiscoKeyUpdate, ok bool) {
|
||||
if pp.IPProto != ipproto.TSMP {
|
||||
return
|
||||
}
|
||||
p := pp.Payload()
|
||||
if len(p) < 33 || p[0] != byte(TSMPTypeDiscoKeyUpdate) {
|
||||
return
|
||||
}
|
||||
copy(update.DiscoKey[:], p[1:33])
|
||||
return update, true
|
||||
}
|
||||
|
||||
func (h TSMPDiscoKeyUpdate) Len() int {
|
||||
return h.IPHeader.Len() + 33
|
||||
}
|
||||
|
||||
func (h TSMPDiscoKeyUpdate) Marshal(buf []byte) error {
|
||||
if len(buf) < h.Len() {
|
||||
return errSmallBuffer
|
||||
}
|
||||
if err := h.IPHeader.Marshal(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
buf = buf[h.IPHeader.Len():]
|
||||
buf[0] = byte(TSMPTypeDiscoKeyUpdate)
|
||||
copy(buf[1:33], h.DiscoKey[:])
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -4,14 +4,8 @@
|
||||
package packet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"go4.org/mem"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func TestTailscaleRejectedHeader(t *testing.T) {
|
||||
@ -78,61 +72,168 @@ func TestTailscaleRejectedHeader(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTSMPDiscoKeyAdvertisementMarshal(t *testing.T) {
|
||||
var (
|
||||
// IPv4: Ver(4)Len(5), TOS, Len(53), ID, Flags, TTL(64), Proto(99), Cksum
|
||||
headerV4, _ = hex.DecodeString("45000035000000004063705d")
|
||||
// IPv6: Ver(6)TCFlow, Len(33), NextHdr(99), HopLim(64)
|
||||
headerV6, _ = hex.DecodeString("6000000000216340")
|
||||
func TestTSMPDiscoKeyRequest(t *testing.T) {
|
||||
t.Run("Manual", func(t *testing.T) {
|
||||
var payload [1]byte
|
||||
payload[0] = byte(TSMPTypeDiscoKeyRequest)
|
||||
|
||||
packetType = []byte{'a'}
|
||||
testKey = bytes.Repeat([]byte{'a'}, 32)
|
||||
var p Parsed
|
||||
p.IPProto = TSMP
|
||||
p.dataofs = 40 // simulate after IP header
|
||||
buf := make([]byte, 40+1)
|
||||
copy(buf[40:], payload[:])
|
||||
p.b = buf
|
||||
p.length = len(buf)
|
||||
|
||||
// IPs
|
||||
srcV4 = netip.MustParseAddr("1.2.3.4")
|
||||
dstV4 = netip.MustParseAddr("4.3.2.1")
|
||||
srcV6 = netip.MustParseAddr("2001:db8::1")
|
||||
dstV6 = netip.MustParseAddr("2001:db8::2")
|
||||
)
|
||||
_, ok := p.AsTSMPDiscoKeyRequest()
|
||||
if !ok {
|
||||
t.Fatal("failed to parse TSMP disco key request")
|
||||
}
|
||||
})
|
||||
|
||||
join := func(parts ...[]byte) []byte {
|
||||
return bytes.Join(parts, nil)
|
||||
}
|
||||
t.Run("RoundTripIPv4", func(t *testing.T) {
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tka TSMPDiscoKeyAdvertisement
|
||||
want []byte
|
||||
}{
|
||||
{
|
||||
name: "v4Header",
|
||||
tka: TSMPDiscoKeyAdvertisement{
|
||||
Src: srcV4,
|
||||
Dst: dstV4,
|
||||
Key: key.DiscoPublicFromRaw32(mem.B(testKey)),
|
||||
},
|
||||
want: join(headerV4, srcV4.AsSlice(), dstV4.AsSlice(), packetType, testKey),
|
||||
},
|
||||
{
|
||||
name: "v6Header",
|
||||
tka: TSMPDiscoKeyAdvertisement{
|
||||
Src: srcV6,
|
||||
Dst: dstV6,
|
||||
Key: key.DiscoPublicFromRaw32(mem.B(testKey)),
|
||||
},
|
||||
want: join(headerV6, srcV6.AsSlice(), dstV6.AsSlice(), packetType, testKey),
|
||||
},
|
||||
}
|
||||
iph := IP4Header{
|
||||
IPProto: TSMP,
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.tka.Marshal()
|
||||
if err != nil {
|
||||
t.Errorf("error mashalling TSMPDiscoAdvertisement: %s", err)
|
||||
}
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("error mashalling TSMPDiscoAdvertisement, expected: \n%x, \ngot:\n%x", tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
var payload [1]byte
|
||||
payload[0] = byte(TSMPTypeDiscoKeyRequest)
|
||||
|
||||
pkt := Generate(iph, payload[:])
|
||||
t.Logf("Generated packet: %d bytes, hex: %x", len(pkt), pkt)
|
||||
|
||||
// Manually check what decode4 would see
|
||||
if len(pkt) >= 4 {
|
||||
declaredLen := int(uint16(pkt[2])<<8 | uint16(pkt[3]))
|
||||
t.Logf("Packet buffer length: %d, IP header declares length: %d", len(pkt), declaredLen)
|
||||
t.Logf("Protocol byte at [9]: 0x%02x = %d", pkt[9], pkt[9])
|
||||
}
|
||||
|
||||
var p Parsed
|
||||
p.Decode(pkt)
|
||||
t.Logf("Decoded: IPVersion=%d IPProto=%v Src=%v Dst=%v length=%d dataofs=%d",
|
||||
p.IPVersion, p.IPProto, p.Src, p.Dst, p.length, p.dataofs)
|
||||
|
||||
if p.IPVersion != 4 {
|
||||
t.Errorf("IPVersion = %d, want 4", p.IPVersion)
|
||||
}
|
||||
if p.IPProto != TSMP {
|
||||
t.Errorf("IPProto = %v, want TSMP", p.IPProto)
|
||||
}
|
||||
if p.Src.Addr() != src {
|
||||
t.Errorf("Src = %v, want %v", p.Src.Addr(), src)
|
||||
}
|
||||
if p.Dst.Addr() != dst {
|
||||
t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst)
|
||||
}
|
||||
|
||||
_, ok := p.AsTSMPDiscoKeyRequest()
|
||||
if !ok {
|
||||
t.Fatal("failed to parse TSMP disco key request from generated packet")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RoundTripIPv6", func(t *testing.T) {
|
||||
src := netip.MustParseAddr("2001:db8::1")
|
||||
dst := netip.MustParseAddr("2001:db8::2")
|
||||
|
||||
iph := IP6Header{
|
||||
IPProto: TSMP,
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
}
|
||||
|
||||
var payload [1]byte
|
||||
payload[0] = byte(TSMPTypeDiscoKeyRequest)
|
||||
|
||||
pkt := Generate(iph, payload[:])
|
||||
t.Logf("Generated packet: %d bytes", len(pkt))
|
||||
|
||||
var p Parsed
|
||||
p.Decode(pkt)
|
||||
|
||||
if p.IPVersion != 6 {
|
||||
t.Errorf("IPVersion = %d, want 6", p.IPVersion)
|
||||
}
|
||||
if p.IPProto != TSMP {
|
||||
t.Errorf("IPProto = %v, want TSMP", p.IPProto)
|
||||
}
|
||||
if p.Src.Addr() != src {
|
||||
t.Errorf("Src = %v, want %v", p.Src.Addr(), src)
|
||||
}
|
||||
if p.Dst.Addr() != dst {
|
||||
t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst)
|
||||
}
|
||||
|
||||
_, ok := p.AsTSMPDiscoKeyRequest()
|
||||
if !ok {
|
||||
t.Fatal("failed to parse TSMP disco key request from generated packet")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTSMPDiscoKeyUpdate(t *testing.T) {
|
||||
var discoKey [32]byte
|
||||
for i := range discoKey {
|
||||
discoKey[i] = byte(i + 10)
|
||||
}
|
||||
|
||||
t.Run("IPv4", func(t *testing.T) {
|
||||
update := TSMPDiscoKeyUpdate{
|
||||
IPHeader: IP4Header{
|
||||
IPProto: TSMP,
|
||||
Src: netip.MustParseAddr("1.2.3.4"),
|
||||
Dst: netip.MustParseAddr("5.6.7.8"),
|
||||
},
|
||||
DiscoKey: discoKey,
|
||||
}
|
||||
|
||||
pkt := make([]byte, update.Len())
|
||||
if err := update.Marshal(pkt); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var p Parsed
|
||||
p.Decode(pkt)
|
||||
|
||||
parsed, ok := p.AsTSMPDiscoKeyUpdate()
|
||||
if !ok {
|
||||
t.Fatal("failed to parse TSMP disco key update")
|
||||
}
|
||||
if parsed.DiscoKey != discoKey {
|
||||
t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IPv6", func(t *testing.T) {
|
||||
update := TSMPDiscoKeyUpdate{
|
||||
IPHeader: IP6Header{
|
||||
IPProto: TSMP,
|
||||
Src: netip.MustParseAddr("2001:db8::1"),
|
||||
Dst: netip.MustParseAddr("2001:db8::2"),
|
||||
},
|
||||
DiscoKey: discoKey,
|
||||
}
|
||||
|
||||
pkt := make([]byte, update.Len())
|
||||
if err := update.Marshal(pkt); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var p Parsed
|
||||
p.Decode(pkt)
|
||||
|
||||
parsed, ok := p.AsTSMPDiscoKeyUpdate()
|
||||
if !ok {
|
||||
t.Fatal("failed to parse TSMP disco key update")
|
||||
}
|
||||
if parsed.DiscoKey != discoKey {
|
||||
t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -194,6 +194,10 @@ type Wrapper struct {
|
||||
// false otherwise.
|
||||
OnICMPEchoResponseReceived func(*packet.Parsed) bool
|
||||
|
||||
// GetDiscoPublicKey, if non-nil, returns the local node's disco public key.
|
||||
// This is called when responding to TSMP disco key requests.
|
||||
GetDiscoPublicKey func() key.DiscoPublic
|
||||
|
||||
// PeerAPIPort, if non-nil, returns the peerapi port that's
|
||||
// running for the given IP address.
|
||||
PeerAPIPort func(netip.Addr) (port uint16, ok bool)
|
||||
@ -211,8 +215,8 @@ type Wrapper struct {
|
||||
|
||||
metrics *metrics
|
||||
|
||||
eventClient *eventbus.Client
|
||||
discoKeyAdvertisementPub *eventbus.Publisher[DiscoKeyAdvertisement]
|
||||
eventClient *eventbus.Client
|
||||
discoKeyUpdatePub *eventbus.Publisher[DiscoKeyUpdate]
|
||||
}
|
||||
|
||||
type metrics struct {
|
||||
@ -227,6 +231,12 @@ func registerMetrics(reg *usermetric.Registry) *metrics {
|
||||
}
|
||||
}
|
||||
|
||||
// DiscoKeyUpdate is published on the event bus when a TSMP disco key update is received.
|
||||
type DiscoKeyUpdate struct {
|
||||
SrcIP netip.Addr
|
||||
Key [32]byte
|
||||
}
|
||||
|
||||
// tunInjectedRead is an injected packet pretending to be a tun.Read().
|
||||
type tunInjectedRead struct {
|
||||
// Only one of packet or data should be set, and are read in that order of
|
||||
@ -288,7 +298,7 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry,
|
||||
}
|
||||
|
||||
w.eventClient = bus.Client("net.tstun")
|
||||
w.discoKeyAdvertisementPub = eventbus.Publish[DiscoKeyAdvertisement](w.eventClient)
|
||||
w.discoKeyUpdatePub = eventbus.Publish[DiscoKeyUpdate](w.eventClient)
|
||||
|
||||
w.vectorBuffer = make([][]byte, tdev.BatchSize())
|
||||
for i := range w.vectorBuffer {
|
||||
@ -1126,11 +1136,6 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i
|
||||
return n, err
|
||||
}
|
||||
|
||||
type DiscoKeyAdvertisement struct {
|
||||
Src netip.Addr
|
||||
Key key.DiscoPublic
|
||||
}
|
||||
|
||||
func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook packet.CaptureCallback, pc *peerConfigTable, gro *gro.GRO) (filter.Response, *gro.GRO) {
|
||||
if captHook != nil {
|
||||
captHook(packet.FromPeer, t.now(), p.Buffer(), p.CaptureMeta)
|
||||
@ -1141,16 +1146,21 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook pa
|
||||
t.noteActivity()
|
||||
t.injectOutboundPong(p, pingReq)
|
||||
return filter.DropSilently, gro
|
||||
} else if discoKeyAdvert, ok := p.AsTSMPDiscoAdvertisement(); ok {
|
||||
t.discoKeyAdvertisementPub.Publish(DiscoKeyAdvertisement{
|
||||
Src: discoKeyAdvert.Src,
|
||||
Key: discoKeyAdvert.Key,
|
||||
})
|
||||
return filter.DropSilently, gro
|
||||
} else if data, ok := p.AsTSMPPong(); ok {
|
||||
if f := t.OnTSMPPongReceived; f != nil {
|
||||
f(data)
|
||||
}
|
||||
} else if _, ok := p.AsTSMPDiscoKeyRequest(); ok {
|
||||
t.noteActivity()
|
||||
t.injectOutboundDiscoKeyUpdate(p)
|
||||
return filter.DropSilently, gro
|
||||
} else if discoKeyUpdate, ok := p.AsTSMPDiscoKeyUpdate(); ok {
|
||||
// Publish to eventbus for subscribers
|
||||
t.discoKeyUpdatePub.Publish(DiscoKeyUpdate{
|
||||
SrcIP: p.Src.Addr(),
|
||||
Key: discoKeyUpdate.DiscoKey,
|
||||
})
|
||||
return filter.DropSilently, gro
|
||||
}
|
||||
}
|
||||
|
||||
@ -1459,6 +1469,36 @@ func (t *Wrapper) injectOutboundPong(pp *packet.Parsed, req packet.TSMPPingReque
|
||||
t.InjectOutbound(packet.Generate(pong, nil))
|
||||
}
|
||||
|
||||
func (t *Wrapper) injectOutboundDiscoKeyUpdate(pp *packet.Parsed) {
|
||||
if t.GetDiscoPublicKey == nil {
|
||||
return
|
||||
}
|
||||
|
||||
discoKey := t.GetDiscoPublicKey()
|
||||
if discoKey.IsZero() {
|
||||
return
|
||||
}
|
||||
|
||||
update := packet.TSMPDiscoKeyUpdate{
|
||||
DiscoKey: discoKey.Raw32(),
|
||||
}
|
||||
|
||||
switch pp.IPVersion {
|
||||
case 4:
|
||||
h4 := pp.IP4Header()
|
||||
h4.ToResponse()
|
||||
update.IPHeader = h4
|
||||
case 6:
|
||||
h6 := pp.IP6Header()
|
||||
h6.ToResponse()
|
||||
update.IPHeader = h6
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
t.InjectOutbound(packet.Generate(update, nil))
|
||||
}
|
||||
|
||||
// InjectOutbound makes the Wrapper device behave as if a packet
|
||||
// with the given contents was sent to the network.
|
||||
// It does not block, but takes ownership of the packet.
|
||||
|
||||
@ -966,28 +966,57 @@ func TestCaptureHook(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTSMPDisco(t *testing.T) {
|
||||
t.Run("IPv6DiscoAdvert", func(t *testing.T) {
|
||||
t.Run("DiscoKeyRequest", func(t *testing.T) {
|
||||
src := netip.MustParseAddr("2001:db8::1")
|
||||
dst := netip.MustParseAddr("2001:db8::2")
|
||||
discoKey := key.NewDisco()
|
||||
buf, _ := (&packet.TSMPDiscoKeyAdvertisement{
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
Key: discoKey.Public(),
|
||||
}).Marshal()
|
||||
|
||||
iph := packet.IP6Header{
|
||||
IPProto: ipproto.TSMP,
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
}
|
||||
|
||||
var payload [1]byte
|
||||
payload[0] = byte(packet.TSMPTypeDiscoKeyRequest)
|
||||
buf := packet.Generate(iph, payload[:])
|
||||
|
||||
var p packet.Parsed
|
||||
p.Decode(buf)
|
||||
|
||||
tda, ok := p.AsTSMPDiscoAdvertisement()
|
||||
_, ok := p.AsTSMPDiscoKeyRequest()
|
||||
if !ok {
|
||||
t.Error("Unable to parse message as TSMPDiscoAdversitement")
|
||||
t.Error("Unable to parse message as TSMPDiscoKeyRequest")
|
||||
}
|
||||
if tda.Src != src {
|
||||
t.Errorf("Src address did not match, expected %v, got %v", src, tda.Src)
|
||||
})
|
||||
|
||||
t.Run("DiscoKeyUpdate", func(t *testing.T) {
|
||||
src := netip.MustParseAddr("2001:db8::1")
|
||||
dst := netip.MustParseAddr("2001:db8::2")
|
||||
discoKey := key.NewDisco()
|
||||
|
||||
update := packet.TSMPDiscoKeyUpdate{
|
||||
IPHeader: packet.IP6Header{
|
||||
IPProto: ipproto.TSMP,
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
},
|
||||
DiscoKey: discoKey.Public().Raw32(),
|
||||
}
|
||||
if !reflect.DeepEqual(tda.Key, discoKey.Public()) {
|
||||
t.Errorf("Key did not match, expected %q, got %q", discoKey.Public(), tda.Key)
|
||||
|
||||
buf := make([]byte, update.Len())
|
||||
if err := update.Marshal(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var p packet.Parsed
|
||||
p.Decode(buf)
|
||||
|
||||
parsed, ok := p.AsTSMPDiscoKeyUpdate()
|
||||
if !ok {
|
||||
t.Error("Unable to parse message as TSMPDiscoKeyUpdate")
|
||||
}
|
||||
if parsed.DiscoKey != discoKey.Public().Raw32() {
|
||||
t.Errorf("Key did not match, expected %v, got %v", discoKey.Public().Raw32(), parsed.DiscoKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -721,6 +721,16 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
|
||||
update(0, netip.AddrPortFrom(ep.nodeAddr, 0), srcAddr.ap, 1, dm.n, true)
|
||||
}
|
||||
|
||||
// Request disco key from peer via TSMP if we receive a WireGuard handshake
|
||||
// over DERP without recent disco success. This handles the "WireGuard-first"
|
||||
// case where WireGuard establishes a tunnel via DERP before disco succeeds
|
||||
// (e.g., control plane unreachable or stale disco keys).
|
||||
if looksLikeWireGuardHandshake(b[:n]) && n > 0 {
|
||||
c.mu.Lock()
|
||||
c.requestDiscoKeyViaTSMPLocked(dm.src, ep)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
c.metrics.inboundPacketsDERPTotal.Add(1)
|
||||
c.metrics.inboundBytesDERPTotal.Add(int64(n))
|
||||
return n, ep
|
||||
|
||||
@ -178,9 +178,10 @@ type Conn struct {
|
||||
|
||||
// A publisher for synchronization points to ensure correct ordering of
|
||||
// config changes between magicsock and wireguard.
|
||||
syncPub *eventbus.Publisher[syncPoint]
|
||||
allocRelayEndpointPub *eventbus.Publisher[UDPRelayAllocReq]
|
||||
portUpdatePub *eventbus.Publisher[router.PortUpdate]
|
||||
syncPub *eventbus.Publisher[syncPoint]
|
||||
allocRelayEndpointPub *eventbus.Publisher[UDPRelayAllocReq]
|
||||
portUpdatePub *eventbus.Publisher[router.PortUpdate]
|
||||
tsmpDiscoKeyRequestPub *eventbus.Publisher[TSMPDiscoKeyRequest]
|
||||
|
||||
// pconn4 and pconn6 are the underlying UDP sockets used to
|
||||
// send/receive packets for wireguard and other magicsock
|
||||
@ -572,6 +573,14 @@ type UDPRelayAllocReq struct {
|
||||
Message *disco.AllocateUDPRelayEndpointRequest
|
||||
}
|
||||
|
||||
// TSMPDiscoKeyRequest is published on the event bus when magicsock needs to
|
||||
// send a TSMP disco key request to a peer. Subscribers should inject the
|
||||
// TSMP packet into the tunnel device.
|
||||
type TSMPDiscoKeyRequest struct {
|
||||
DstIP netip.Addr
|
||||
MetricSent *clientmetric.Metric
|
||||
}
|
||||
|
||||
// UDPRelayAllocResp represents a [*disco.AllocateUDPRelayEndpointResponse]
|
||||
// that is yet to be transmitted over DERP (or delivered locally if
|
||||
// ReqRxFromNodeKey is self). This is signaled over an [eventbus.Bus] from
|
||||
@ -691,6 +700,7 @@ func NewConn(opts Options) (*Conn, error) {
|
||||
c.syncPub = eventbus.Publish[syncPoint](ec)
|
||||
c.allocRelayEndpointPub = eventbus.Publish[UDPRelayAllocReq](ec)
|
||||
c.portUpdatePub = eventbus.Publish[router.PortUpdate](ec)
|
||||
c.tsmpDiscoKeyRequestPub = eventbus.Publish[TSMPDiscoKeyRequest](ec)
|
||||
eventbus.SubscribeFunc(ec, c.onPortMapChanged)
|
||||
eventbus.SubscribeFunc(ec, c.onFilterUpdate)
|
||||
eventbus.SubscribeFunc(ec, c.onNodeViewsUpdate)
|
||||
@ -1800,6 +1810,15 @@ func looksLikeInitiationMsg(b []byte) bool {
|
||||
binary.LittleEndian.Uint32(b) == device.MessageInitiationType
|
||||
}
|
||||
|
||||
func looksLikeWireGuardHandshake(b []byte) bool {
|
||||
if len(b) < 4 {
|
||||
return false
|
||||
}
|
||||
msgType := binary.LittleEndian.Uint32(b)
|
||||
return (len(b) == device.MessageInitiationSize && msgType == device.MessageInitiationType) ||
|
||||
(len(b) == device.MessageResponseSize && msgType == device.MessageResponseType)
|
||||
}
|
||||
|
||||
// receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6.
|
||||
//
|
||||
// size is the length of 'b' to report up to wireguard-go (only relevant if
|
||||
@ -4104,6 +4123,12 @@ var (
|
||||
metricUDPLifetimeCycleCompleteAt10sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_10s_cliff")
|
||||
metricUDPLifetimeCycleCompleteAt30sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_30s_cliff")
|
||||
metricUDPLifetimeCycleCompleteAt60sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_60s_cliff")
|
||||
|
||||
// TSMP disco key exchange
|
||||
metricTSMPDiscoKeyRequestSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_request_sent")
|
||||
metricTSMPDiscoKeyUpdateReceived = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_received")
|
||||
metricTSMPDiscoKeyUpdateApplied = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_applied")
|
||||
metricTSMPDiscoKeyUpdateUnknown = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_unknown_peer")
|
||||
)
|
||||
|
||||
// newUDPLifetimeCounter returns a new *clientmetric.Metric with the provided
|
||||
@ -4242,6 +4267,81 @@ func (le *lazyEndpoint) FromPeer(peerPublicKey [32]byte) {
|
||||
// See http://go/corp/29422 & http://go/corp/30042
|
||||
le.c.peerMap.setNodeKeyForEpAddr(le.src, pubKey)
|
||||
le.c.logf("magicsock: lazyEndpoint.FromPeer(%v) setting epAddr(%v) in peerMap for node(%v)", pubKey.ShortString(), le.src, ep.nodeAddr)
|
||||
|
||||
le.c.requestDiscoKeyViaTSMPLocked(pubKey, ep)
|
||||
}
|
||||
|
||||
// requestDiscoKeyViaTSMPLocked sends a TSMP disco key request to a peer if there
|
||||
// hasn't been a recent disco ping.
|
||||
// c.mu must be held.
|
||||
func (c *Conn) requestDiscoKeyViaTSMPLocked(nodeKey key.NodePublic, ep *endpoint) {
|
||||
if !ep.nodeAddr.IsValid() {
|
||||
return
|
||||
}
|
||||
|
||||
epDisco := ep.disco.Load()
|
||||
if epDisco != nil {
|
||||
di := c.discoInfo[epDisco.key]
|
||||
recentDiscoPing := di != nil && time.Since(di.lastPingTime) < discoPingInterval
|
||||
|
||||
if recentDiscoPing {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go c.tsmpDiscoKeyRequestPub.Publish(TSMPDiscoKeyRequest{DstIP: ep.nodeAddr, MetricSent: metricTSMPDiscoKeyRequestSent})
|
||||
}
|
||||
|
||||
// HandleDiscoKeyUpdate processes a TSMP disco key update.
|
||||
// The update may be solicited (in response to a request) or unsolicited.
|
||||
// srcIP is the Tailscale IP of the peer that sent the update.
|
||||
func (c *Conn) HandleDiscoKeyUpdate(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
|
||||
discoKey := key.DiscoPublicFromRaw32(mem.B(update.DiscoKey[:]))
|
||||
c.logf("magicsock: received disco key update %v from %v", discoKey.ShortString(), srcIP)
|
||||
metricTSMPDiscoKeyUpdateReceived.Add(1)
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
var nodeKey key.NodePublic
|
||||
var found bool
|
||||
for _, peer := range c.peers.All() {
|
||||
for _, addr := range peer.Addresses().All() {
|
||||
if addr.Addr() == srcIP {
|
||||
nodeKey = peer.Key()
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if found {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
c.logf("magicsock: disco key update from unknown peer %v", srcIP)
|
||||
metricTSMPDiscoKeyUpdateUnknown.Add(1)
|
||||
return
|
||||
}
|
||||
|
||||
ep, ok := c.peerMap.endpointForNodeKey(nodeKey)
|
||||
if !ok {
|
||||
c.logf("magicsock: endpoint not found for node %v", nodeKey.ShortString())
|
||||
return
|
||||
}
|
||||
|
||||
oldDiscoKey := key.DiscoPublic{}
|
||||
if epDisco := ep.disco.Load(); epDisco != nil {
|
||||
oldDiscoKey = epDisco.key
|
||||
}
|
||||
c.discoInfoForKnownPeerLocked(discoKey)
|
||||
ep.disco.Store(&endpointDisco{
|
||||
key: discoKey,
|
||||
short: discoKey.ShortString(),
|
||||
})
|
||||
c.peerMap.upsertEndpoint(ep, oldDiscoKey)
|
||||
c.logf("magicsock: updated disco key for peer %v to %v", nodeKey.ShortString(), discoKey.ShortString())
|
||||
metricTSMPDiscoKeyUpdateApplied.Add(1)
|
||||
}
|
||||
|
||||
// PeerRelays returns the current set of candidate peer relays.
|
||||
|
||||
@ -64,6 +64,7 @@ import (
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/types/nettype"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/cibuild"
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/eventbus"
|
||||
@ -4302,3 +4303,66 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) {
|
||||
keys = append(keys, newKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendTSMPDiscoKeyRequest(t *testing.T) {
|
||||
ep := &endpoint{
|
||||
nodeID: 1,
|
||||
publicKey: key.NewNode().Public(),
|
||||
nodeAddr: netip.MustParseAddr("100.64.0.1"),
|
||||
}
|
||||
discoKey := key.NewDisco().Public()
|
||||
ep.disco.Store(&endpointDisco{
|
||||
key: discoKey,
|
||||
short: discoKey.ShortString(),
|
||||
})
|
||||
bus := eventbustest.NewBus(t)
|
||||
conn := newConn(t.Logf)
|
||||
conn.eventBus = bus
|
||||
conn.eventClient = bus.Client("magicsock.Conn.test")
|
||||
conn.tsmpDiscoKeyRequestPub = eventbus.Publish[TSMPDiscoKeyRequest](conn.eventClient)
|
||||
ep.c = conn
|
||||
|
||||
tsmpRequestCalled := make(chan struct{}, 1)
|
||||
var capturedIP netip.Addr
|
||||
ec := bus.Client("test")
|
||||
defer ec.Close()
|
||||
eventbus.SubscribeFunc(ec, func(req TSMPDiscoKeyRequest) {
|
||||
capturedIP = req.DstIP
|
||||
if req.MetricSent != nil {
|
||||
req.MetricSent.Add(1)
|
||||
}
|
||||
tsmpRequestCalled <- struct{}{}
|
||||
})
|
||||
|
||||
conn.mu.Lock()
|
||||
conn.peers = views.SliceOf([]tailcfg.NodeView{
|
||||
(&tailcfg.Node{
|
||||
Key: ep.publicKey,
|
||||
Addresses: []netip.Prefix{
|
||||
netip.MustParsePrefix("100.64.0.1/32"),
|
||||
},
|
||||
}).View(),
|
||||
})
|
||||
conn.mu.Unlock()
|
||||
|
||||
var pubKey [32]byte
|
||||
copy(pubKey[:], ep.publicKey.AppendTo(nil))
|
||||
conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{})
|
||||
|
||||
le := &lazyEndpoint{
|
||||
c: conn,
|
||||
src: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777")},
|
||||
}
|
||||
|
||||
le.FromPeer(pubKey)
|
||||
|
||||
select {
|
||||
case <-tsmpRequestCalled:
|
||||
if !capturedIP.IsValid() {
|
||||
t.Error("TSMP request sent with invalid IP")
|
||||
}
|
||||
t.Logf("TSMP disco key request sent to %v", capturedIP)
|
||||
case <-time.After(time.Second):
|
||||
t.Error("TSMP disco key request was not sent")
|
||||
}
|
||||
}
|
||||
|
||||
@ -54,6 +54,7 @@ import (
|
||||
"tailscale.com/util/execqueue"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/set"
|
||||
"tailscale.com/util/singleflight"
|
||||
"tailscale.com/util/testenv"
|
||||
"tailscale.com/util/usermetric"
|
||||
"tailscale.com/version"
|
||||
@ -469,6 +470,13 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
|
||||
return true
|
||||
}
|
||||
|
||||
e.tundev.GetDiscoPublicKey = func() key.DiscoPublic {
|
||||
if e.magicConn == nil {
|
||||
return key.DiscoPublic{}
|
||||
}
|
||||
return e.magicConn.DiscoPublicKey()
|
||||
}
|
||||
|
||||
// wgdev takes ownership of tundev, will close it when closed.
|
||||
e.logf("Creating WireGuard device...")
|
||||
e.wgdev = wgcfg.NewDevice(e.tundev, e.magicConn.Bind(), e.wgLogger.DeviceLogger)
|
||||
@ -549,6 +557,36 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
|
||||
}
|
||||
e.linkChangeQueue.Add(func() { e.linkChange(&cd) })
|
||||
})
|
||||
eventbus.SubscribeFunc(ec, func(update tstun.DiscoKeyUpdate) {
|
||||
e.logf("wgengine: got TSMP disco key update from %v via eventbus", update.SrcIP)
|
||||
if e.magicConn != nil {
|
||||
pkt := packet.TSMPDiscoKeyUpdate{
|
||||
DiscoKey: update.Key,
|
||||
}
|
||||
e.magicConn.HandleDiscoKeyUpdate(update.SrcIP, pkt)
|
||||
}
|
||||
})
|
||||
var tsmpRequestGroup singleflight.Group[netip.Addr, struct{}]
|
||||
eventbus.SubscribeFunc(ec, func(req magicsock.TSMPDiscoKeyRequest) {
|
||||
go tsmpRequestGroup.Do(req.DstIP, func() (struct{}, error) {
|
||||
// DiscoKeyRequests are triggered by an incoming WireGuard handshake
|
||||
// initiation arriving before a disco ping, which is a likely
|
||||
// indicator that disco pings failed due to a lack of key
|
||||
// synchronization. If the requests are sent immediately, before the
|
||||
// handshake state is accepted in the WireGuard client state
|
||||
// machine, this starts a new session, and the two peer state
|
||||
// machines conflict, causing loss and additional delays. Delaying
|
||||
// the send avoids this, so coalesce duplicate sends, and delay them
|
||||
// by a short time to avoid the state machine conflict.
|
||||
time.Sleep(time.Millisecond)
|
||||
if err := e.sendTSMPDiscoKeyRequest(req.DstIP); err != nil {
|
||||
e.logf("wgengine: failed to send TSMP disco key request: %v", err)
|
||||
}
|
||||
e.logf("wgengine: sending TSMP disco key request to %v", req.DstIP)
|
||||
req.MetricSent.Add(1)
|
||||
return struct{}{}, nil
|
||||
})
|
||||
})
|
||||
e.eventClient = ec
|
||||
e.logf("Engine created.")
|
||||
return e, nil
|
||||
@ -1436,7 +1474,6 @@ func (e *userspaceEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size in
|
||||
e.magicConn.Ping(peer, res, size, cb)
|
||||
case "TSMP":
|
||||
e.sendTSMPPing(ip, peer, res, cb)
|
||||
e.sendTSMPDiscoAdvertisement(ip)
|
||||
case "ICMP":
|
||||
e.sendICMPEchoRequest(ip, peer, res, cb)
|
||||
}
|
||||
@ -1557,29 +1594,6 @@ func (e *userspaceEngine) sendTSMPPing(ip netip.Addr, peer tailcfg.NodeView, res
|
||||
e.tundev.InjectOutbound(tsmpPing)
|
||||
}
|
||||
|
||||
func (e *userspaceEngine) sendTSMPDiscoAdvertisement(ip netip.Addr) {
|
||||
srcIP, err := e.mySelfIPMatchingFamily(ip)
|
||||
if err != nil {
|
||||
e.logf("getting matching node: %s", err)
|
||||
return
|
||||
}
|
||||
tdka := packet.TSMPDiscoKeyAdvertisement{
|
||||
Src: srcIP,
|
||||
Dst: ip,
|
||||
Key: e.magicConn.DiscoPublicKey(),
|
||||
}
|
||||
payload, err := tdka.Marshal()
|
||||
if err != nil {
|
||||
e.logf("error generating TSMP Advertisement: %s", err)
|
||||
metricTSMPDiscoKeyAdvertisementError.Add(1)
|
||||
} else if err := e.tundev.InjectOutbound(payload); err != nil {
|
||||
e.logf("error sending TSMP Advertisement: %s", err)
|
||||
metricTSMPDiscoKeyAdvertisementError.Add(1)
|
||||
} else {
|
||||
metricTSMPDiscoKeyAdvertisementSent.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPPongReply)) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
@ -1593,6 +1607,35 @@ func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPP
|
||||
}
|
||||
}
|
||||
|
||||
// sendTSMPDiscoKeyRequest sends a TSMP disco key request to the given peer IP.
|
||||
func (e *userspaceEngine) sendTSMPDiscoKeyRequest(ip netip.Addr) error {
|
||||
srcIP, err := e.mySelfIPMatchingFamily(ip)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var iph packet.Header
|
||||
if srcIP.Is4() {
|
||||
iph = packet.IP4Header{
|
||||
IPProto: ipproto.TSMP,
|
||||
Src: srcIP,
|
||||
Dst: ip,
|
||||
}
|
||||
} else {
|
||||
iph = packet.IP6Header{
|
||||
IPProto: ipproto.TSMP,
|
||||
Src: srcIP,
|
||||
Dst: ip,
|
||||
}
|
||||
}
|
||||
|
||||
var tsmpPayload [1]byte
|
||||
tsmpPayload[0] = byte(packet.TSMPTypeDiscoKeyRequest)
|
||||
|
||||
tsmpRequest := packet.Generate(iph, tsmpPayload[:])
|
||||
return e.tundev.InjectOutbound(tsmpRequest)
|
||||
}
|
||||
|
||||
func (e *userspaceEngine) setICMPEchoResponseCallback(idSeq uint32, cb func()) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
@ -1746,9 +1789,6 @@ var (
|
||||
|
||||
metricNumMajorChanges = clientmetric.NewCounter("wgengine_major_changes")
|
||||
metricNumMinorChanges = clientmetric.NewCounter("wgengine_minor_changes")
|
||||
|
||||
metricTSMPDiscoKeyAdvertisementSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_sent")
|
||||
metricTSMPDiscoKeyAdvertisementError = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_error")
|
||||
)
|
||||
|
||||
func (e *userspaceEngine) InstallCaptureHook(cb packet.CaptureCallback) {
|
||||
|
||||
@ -325,7 +325,7 @@ func TestUserspaceEnginePeerMTUReconfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTSMPKeyAdvertisement(t *testing.T) {
|
||||
func TestTSMPDiscoKeyRequest(t *testing.T) {
|
||||
var knobs controlknobs.Knobs
|
||||
|
||||
bus := eventbustest.NewBus(t)
|
||||
@ -369,13 +369,12 @@ func TestTSMPKeyAdvertisement(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := netip.MustParseAddr("100.100.99.1")
|
||||
previousValue := metricTSMPDiscoKeyAdvertisementSent.Value()
|
||||
ue.sendTSMPDiscoAdvertisement(addr)
|
||||
if val := metricTSMPDiscoKeyAdvertisementSent.Value(); val <= previousValue {
|
||||
errs := metricTSMPDiscoKeyAdvertisementError.Value()
|
||||
t.Errorf("Expected 1 disco key advert, got %d, errors %d", val, errs)
|
||||
peerAddr := netip.MustParseAddr("100.100.99.1")
|
||||
err = ue.sendTSMPDiscoKeyRequest(peerAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("sendTSMPDiscoKeyRequest failed: %v", err)
|
||||
}
|
||||
|
||||
// Remove config to have the engine shut down more consistently
|
||||
err = ue.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{})
|
||||
if err != nil {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user