wgengine/magicsock: add webrtc path to magicsock (experimental)

This commit is contained in:
Adriano Sela Aviles 2026-04-11 12:01:09 -07:00 committed by Adriano Sela Aviles
parent 4ce1643929
commit 413ba38632
No known key found for this signature in database
GPG Key ID: 28128631BCCBB1BB
19 changed files with 2858 additions and 26 deletions

View File

@ -23,6 +23,7 @@ import (
"tailscale.com/ipn"
"tailscale.com/ipn/ipnstate"
"tailscale.com/net/netmon"
"tailscale.com/tailcfg"
"tailscale.com/util/dnsname"
)
@ -196,7 +197,19 @@ func runStatus(ctx context.Context, args []string) error {
if relay != "" && ps.CurAddr == "" && ps.PeerRelay == "" {
f("relay %q", relay)
} else if ps.CurAddr != "" {
f("direct %s", ps.CurAddr)
// Check if this is a WebRTC connection (address matches WebRTC magic IP)
if strings.HasPrefix(ps.CurAddr, tailcfg.WebRTCMagicIP) {
// Extract the actual remote address from CurAddr, which for a WebRTC path
// is of the form "${WEBRTC_MAGIC_IP}:${DUMMY_PORT} (${REAL_IP_AND_PORT})"
// e.g. "127.3.3.41:12345 (134.209.53.229:37792)".
realRemoteAddr := "<UNKNOWN>"
if idx := strings.Index(ps.CurAddr, " ("); idx > 0 {
realRemoteAddr = ps.CurAddr[idx+2 : len(ps.CurAddr)-1]
}
f("webrtc %s", realRemoteAddr)
} else {
f("direct %s", ps.CurAddr)
}
} else if ps.PeerRelay != "" {
f("peer-relay %s", ps.PeerRelay)
}

View File

@ -0,0 +1,5 @@
module github.com/adrianosela/tailscale/example-webrtc-server
go 1.26
require github.com/gorilla/websocket v1.5.3

View File

@ -0,0 +1,2 @@
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=

View File

@ -0,0 +1,466 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
// example-webrtc-server is a WebRTC signaling server that supports both
// WebSocket (for Tailscale) and HTTP REST (for standard WebRTC clients).
package main
import (
"encoding/json"
"flag"
"fmt"
"log"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
)
// SignalingMessage represents a WebRTC signaling message.
// This format is compatible with both Tailscale and standard WebRTC clients.
type SignalingMessage struct {
	Type string `json:"type"` // "offer", "answer", "candidate"
	From string `json:"from"` // sender's disco public key (hex)
	To   string `json:"to"`   // recipient's disco public key (hex)

	// For SDP offer/answer (raw JSON for flexibility).
	// Only the field matching Type is expected to be set.
	Offer     json.RawMessage `json:"offer,omitempty"`
	Answer    json.RawMessage `json:"answer,omitempty"`
	Candidate json.RawMessage `json:"candidate,omitempty"`

	// SDP is a legacy field for HTTP REST clients; when set, handlePostSignal
	// converts it into the structured Offer/Answer form before routing.
	SDP string `json:"sdp,omitempty"` // Used by non-Tailscale clients

	// Timestamp is stamped by RouteMessage when the message is routed and is
	// used by CleanupOldMessages to expire stale queued messages.
	Timestamp time.Time `json:"timestamp"`
}
// Client represents a connected peer (WebSocket or HTTP polling).
type Client struct {
	ID       string          // peer ID: the From field of the client's first message
	Conn     *websocket.Conn // nil for HTTP clients
	LastSeen time.Time       // last activity; inactive clients are reaped by CleanupOldMessages
}
// SignalingServer manages WebRTC signaling between peers.
// All mutable state (clients, messages, stats) is guarded by mu.
type SignalingServer struct {
	mu      sync.RWMutex
	clients map[string]*Client // Active WebSocket clients, keyed by peer ID

	// Message queue for HTTP polling clients.
	messages map[string][]SignalingMessage // Key: "to" peer ID

	upgrader websocket.Upgrader

	// Statistics (guarded by mu).
	stats struct {
		totalMessages  int // messages seen by RouteMessage
		wsConnections  int // currently open WebSocket connections
		httpPolls      int // number of GetMessages calls (HTTP polls)
		activeOffers   int // offers routed but not yet answered
		completedPairs int // answers routed
	}
}
// NewSignalingServer creates a new signaling server with empty client and
// message tables. The WebSocket upgrader accepts requests from any origin.
func NewSignalingServer() *SignalingServer {
	return &SignalingServer{
		clients:  make(map[string]*Client),
		messages: make(map[string][]SignalingMessage),
		upgrader: websocket.Upgrader{
			CheckOrigin: func(r *http.Request) bool {
				return true // Allow all origins (configure as needed)
			},
		},
	}
}
// RouteMessage delivers a message to its destination peer (msg.To).
// Delivery is attempted over the destination's live WebSocket connection
// first; if the peer has no WebSocket, or the write fails, the message is
// queued for HTTP polling instead. clientIP is used only for logging.
// The returned error is currently always nil.
func (s *SignalingServer) RouteMessage(msg SignalingMessage, clientIP string) error {
	msg.Timestamp = time.Now()
	s.mu.Lock()
	defer s.mu.Unlock()

	// Update stats
	s.stats.totalMessages++
	switch msg.Type {
	case "offer":
		s.stats.activeOffers++
		// Clear old messages from this sender when starting a new session
		s.clearOldMessages(msg.From, msg.To)
	case "answer":
		s.stats.completedPairs++
		if s.stats.activeOffers > 0 {
			s.stats.activeOffers--
		}
	}

	log.Printf("[%s] Routing %s from %s to %s", clientIP, msg.Type, msg.From, msg.To)

	// Try to deliver to WebSocket client first.
	// NOTE(review): WriteJSON runs while s.mu is held, so one slow peer
	// blocks all routing; consider per-client write locks if this matters.
	if client, ok := s.clients[msg.To]; ok && client.Conn != nil {
		// Send directly via WebSocket
		if err := client.Conn.WriteJSON(msg); err != nil {
			log.Printf("[%s] Failed to send to WebSocket client %s: %v", clientIP, msg.To, err)
			// Remove dead connection
			delete(s.clients, msg.To)
			// Fall through to queue message
		} else {
			log.Printf("[%s] Delivered %s to WebSocket client %s", clientIP, msg.Type, msg.To)
			return nil
		}
	}

	// Queue for HTTP polling client
	s.messages[msg.To] = append(s.messages[msg.To], msg)
	log.Printf("[%s] Queued %s for HTTP client %s", clientIP, msg.Type, msg.To)
	return nil
}
// clearOldMessages drops any messages queued for peer "to" that were sent by
// peer "from". It is invoked when a new session (offer) starts so that stale
// signaling from a previous attempt is not delivered. Caller must hold s.mu.
func (s *SignalingServer) clearOldMessages(from, to string) {
	queued, ok := s.messages[to]
	if !ok {
		return
	}
	kept := make([]SignalingMessage, 0)
	for _, m := range queued {
		if m.From == from {
			continue
		}
		kept = append(kept, m)
	}
	if len(kept) == 0 {
		delete(s.messages, to)
		return
	}
	s.messages[to] = kept
}
// GetMessages retrieves and removes all queued messages for an HTTP polling
// client identified by peerID. It returns nil when nothing is queued.
// clientIP is used only for logging.
func (s *SignalingServer) GetMessages(peerID, clientIP string) []SignalingMessage {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.stats.httpPolls++

	queued := s.messages[peerID]
	if len(queued) == 0 {
		return nil
	}

	// Hand the whole queue to the caller and forget it.
	delete(s.messages, peerID)
	log.Printf("[%s] Delivering %d queued message(s) to HTTP client %s", clientIP, len(queued), peerID)
	return queued
}
// CleanupOldMessages removes queued messages older than maxAge and closes
// and forgets clients that have been inactive for longer than maxAge.
func (s *SignalingServer) CleanupOldMessages(maxAge time.Duration) {
	s.mu.Lock()
	defer s.mu.Unlock()

	cutoff := time.Now().Add(-maxAge)
	removed := 0

	// Drop expired queued messages, preserving per-peer order.
	for peerID, queue := range s.messages {
		fresh := make([]SignalingMessage, 0, len(queue))
		for _, m := range queue {
			if !m.Timestamp.After(cutoff) {
				removed++
				continue
			}
			fresh = append(fresh, m)
		}
		if len(fresh) == 0 {
			delete(s.messages, peerID)
			continue
		}
		s.messages[peerID] = fresh
	}

	// Close and forget clients that have been idle too long.
	for id, c := range s.clients {
		if time.Since(c.LastSeen) <= maxAge {
			continue
		}
		if c.Conn != nil {
			c.Conn.Close()
		}
		delete(s.clients, id)
		removed++
	}

	if removed > 0 {
		log.Printf("Cleaned up %d old messages/connections", removed)
	}
}
// GetStats returns a snapshot of server statistics as a JSON-friendly map.
func (s *SignalingServer) GetStats() map[string]interface{} {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// Count messages still waiting for an HTTP poll.
	pending := 0
	for _, queue := range s.messages {
		pending += len(queue)
	}

	return map[string]interface{}{
		"total_messages":    s.stats.totalMessages,
		"ws_connections":    s.stats.wsConnections,
		"http_polls":        s.stats.httpPolls,
		"active_offers":     s.stats.activeOffers,
		"completed_pairs":   s.stats.completedPairs,
		"queued_messages":   pending,
		"active_ws_clients": len(s.clients),
		"active_peer_ids":   len(s.messages),
	}
}
// WebSocket Handler (for Tailscale clients)
//
// handleWebSocket upgrades the request to a WebSocket, registers the peer
// under the From field of its first message (so replies can be pushed back
// over this connection), and routes every message it reads to its
// destination. On disconnect the registration is removed only if it still
// refers to this connection, since the peer may already have reconnected
// and re-registered on a newer connection.
func (s *SignalingServer) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("WebSocket upgrade failed: %v", err)
		return
	}
	defer conn.Close()

	var clientID string

	s.mu.Lock()
	s.stats.wsConnections++
	s.mu.Unlock()
	defer func() {
		s.mu.Lock()
		s.stats.wsConnections--
		s.mu.Unlock()
	}()

	log.Printf("[%s] New WebSocket connection", r.RemoteAddr)

	// Read and route messages from this WebSocket client.
	for {
		var msg SignalingMessage
		if err := conn.ReadJSON(&msg); err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("[%s] WebSocket error: %v", r.RemoteAddr, err)
			}
			break
		}

		// Register client on first message. Require a non-empty From so we
		// never register a client under the empty key.
		if clientID == "" && msg.From != "" {
			clientID = msg.From
			s.mu.Lock()
			s.clients[clientID] = &Client{
				ID:       clientID,
				Conn:     conn,
				LastSeen: time.Now(),
			}
			s.mu.Unlock()
			log.Printf("[%s] WebSocket client registered as %s", r.RemoteAddr, clientID)
		}

		// Update last seen (the entry may have been removed by RouteMessage
		// after a failed write; that's fine, we just skip the update).
		if clientID != "" {
			s.mu.Lock()
			if client, ok := s.clients[clientID]; ok {
				client.LastSeen = time.Now()
			}
			s.mu.Unlock()
		}

		// Route the message to destination.
		if err := s.RouteMessage(msg, r.RemoteAddr); err != nil {
			log.Printf("[%s] Failed to route message: %v", r.RemoteAddr, err)
		}
	}

	// Cleanup on disconnect. Delete the registration only if it still points
	// at this connection; deleting unconditionally would unregister a newer
	// connection from the same peer.
	if clientID != "" {
		s.mu.Lock()
		if client, ok := s.clients[clientID]; ok && client.Conn == conn {
			delete(s.clients, clientID)
		}
		s.mu.Unlock()
		log.Printf("[%s] WebSocket client %s disconnected", r.RemoteAddr, clientID)
	}
}
// HTTP REST Handlers (for standard WebRTC clients)
//
// handlePostSignal accepts a signaling message as a JSON POST body,
// validates it, and routes it to the destination peer. Legacy clients that
// send a bare "sdp" string have it converted into the structured
// Offer/Answer form before routing.
func (s *SignalingServer) handlePostSignal(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var msg SignalingMessage
	if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
		http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
		return
	}

	// Validate required fields.
	if msg.From == "" || msg.To == "" || msg.Type == "" {
		http.Error(w, "Missing required fields: from, to, type", http.StatusBadRequest)
		return
	}
	if msg.Type != "offer" && msg.Type != "answer" && msg.Type != "candidate" {
		http.Error(w, "Invalid type, must be 'offer', 'answer', or 'candidate'", http.StatusBadRequest)
		return
	}

	// Convert legacy SDP field to Offer/Answer format for compatibility.
	// Build the JSON with json.Marshal rather than fmt.Sprintf %q: Go string
	// quoting is not JSON escaping (e.g. control bytes become \x.. which is
	// invalid JSON), and SDP bodies are arbitrary client-supplied text.
	if msg.SDP != "" {
		sdpJSON, err := json.Marshal(struct {
			Type string `json:"type"`
			SDP  string `json:"sdp"`
		}{Type: msg.Type, SDP: msg.SDP})
		if err != nil {
			http.Error(w, fmt.Sprintf("Failed to encode SDP: %v", err), http.StatusInternalServerError)
			return
		}
		switch msg.Type {
		case "offer":
			msg.Offer = json.RawMessage(sdpJSON)
		case "answer":
			msg.Answer = json.RawMessage(sdpJSON)
		}
	}

	if err := s.RouteMessage(msg, r.RemoteAddr); err != nil {
		http.Error(w, fmt.Sprintf("Failed to route message: %v", err), http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusCreated)
	json.NewEncoder(w).Encode(map[string]string{
		"status":  "ok",
		"message": fmt.Sprintf("Message routed to %s", msg.To),
	})
}
// handleGetSignal lets a standard WebRTC client poll for queued signaling
// messages addressed to it (identified by the "to" query parameter).
// One message is returned per poll; any remaining messages are re-queued so
// subsequent polls can fetch them (previously they were silently dropped,
// because GetMessages drains the entire queue).
func (s *SignalingServer) handleGetSignal(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	to := r.URL.Query().Get("to")
	if to == "" {
		http.Error(w, "Missing required query parameter: to", http.StatusBadRequest)
		return
	}

	messages := s.GetMessages(to, r.RemoteAddr)
	w.Header().Set("Content-Type", "application/json")
	if len(messages) == 0 {
		w.WriteHeader(http.StatusNotFound)
		json.NewEncoder(w).Encode(map[string]string{
			"status":  "not_found",
			"message": fmt.Sprintf("No messages for %s", to),
		})
		return
	}

	// GetMessages drained the whole queue; put everything after the first
	// message back (ahead of anything queued meanwhile) instead of dropping
	// it. The client polls again for the rest.
	if rest := messages[1:]; len(rest) > 0 {
		s.mu.Lock()
		s.messages[to] = append(rest, s.messages[to]...)
		s.mu.Unlock()
	}

	w.WriteHeader(http.StatusOK)
	// Return first message (client should poll again for more)
	json.NewEncoder(w).Encode(messages[0])
}
// handleStats serves the current server statistics as JSON.
func (s *SignalingServer) handleStats(w http.ResponseWriter, r *http.Request) {
	stats := s.GetStats()
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(stats)
}
// handleHealth serves a static health-check response as JSON.
func (s *SignalingServer) handleHealth(w http.ResponseWriter, r *http.Request) {
	body := map[string]string{
		"status":  "healthy",
		"service": "webrtc-signaling-server",
		"version": "1.0.0",
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(body)
}
// CORS middleware
func corsMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}
next(w, r)
}
}
// Logging middleware
func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
next(w, r)
log.Printf("%s %s %s %s", r.RemoteAddr, r.Method, r.URL.Path, time.Since(start))
}
}
// main starts the signaling server: it parses flags, launches a background
// cleanup loop, wires up the WebSocket and HTTP REST endpoints on the
// default mux, and serves HTTP or HTTPS depending on whether both TLS flags
// are provided.
func main() {
	port := flag.Int("port", 8080, "Port to listen on")
	tlsCert := flag.String("cert", "", "TLS certificate file (optional, for HTTPS)")
	tlsKey := flag.String("key", "", "TLS key file (optional, for HTTPS)")
	cleanupInterval := flag.Duration("cleanup", 5*time.Minute, "Interval for cleaning up old messages")
	messageMaxAge := flag.Duration("max-age", 10*time.Minute, "Maximum age for messages before cleanup")
	flag.Parse()

	server := NewSignalingServer()

	// Start cleanup goroutine. It runs for the life of the process; there is
	// no shutdown path, the process is expected to be killed.
	go func() {
		ticker := time.NewTicker(*cleanupInterval)
		defer ticker.Stop()
		for range ticker.C {
			server.CleanupOldMessages(*messageMaxAge)
		}
	}()

	// Register handlers.
	// WebSocket endpoint (for Tailscale). No CORS wrapper: browsers do not
	// apply CORS to WebSocket upgrades.
	http.HandleFunc("/ws", loggingMiddleware(server.handleWebSocket))
	// HTTP REST endpoints (for standard WebRTC clients).
	http.HandleFunc("/signal", corsMiddleware(loggingMiddleware(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodPost {
			server.handlePostSignal(w, r)
		} else if r.Method == http.MethodGet {
			server.handleGetSignal(w, r)
		} else {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		}
	})))
	// Monitoring endpoints.
	http.HandleFunc("/stats", corsMiddleware(loggingMiddleware(server.handleStats)))
	http.HandleFunc("/health", corsMiddleware(loggingMiddleware(server.handleHealth)))

	addr := fmt.Sprintf(":%d", *port)
	log.Printf("Server starting on %s", addr)
	log.Printf("WebSocket endpoint: ws://localhost%s/ws", addr)
	log.Printf("HTTP REST endpoint: http://localhost%s/signal", addr)
	log.Printf("Cleanup interval: %v, Max message age: %v", *cleanupInterval, *messageMaxAge)
	log.Println("────────────────────────────────────────────────────────────")

	// Serve with TLS only when both cert and key are supplied.
	var err error
	if *tlsCert != "" && *tlsKey != "" {
		log.Printf("Starting HTTPS server with TLS...")
		err = http.ListenAndServeTLS(addr, *tlsCert, *tlsKey, nil)
	} else {
		log.Printf("Starting HTTP server (use -cert and -key for HTTPS)...")
		err = http.ListenAndServe(addr, nil)
	}
	if err != nil {
		log.Fatal(err)
	}
}

19
go.mod
View File

@ -57,6 +57,7 @@ require (
github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806
github.com/google/uuid v1.6.0
github.com/goreleaser/nfpm/v2 v2.33.1
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674
github.com/hashicorp/go-hclog v1.6.2
github.com/hashicorp/raft v1.7.2
github.com/hashicorp/raft-boltdb/v2 v2.3.1
@ -79,6 +80,7 @@ require (
github.com/miekg/dns v1.1.58
github.com/mitchellh/go-ps v1.0.0
github.com/peterbourgon/ff/v3 v3.4.0
github.com/pion/webrtc/v4 v4.0.0
github.com/pires/go-proxyproto v0.8.1
github.com/pkg/errors v0.9.1
github.com/pkg/sftp v1.13.6
@ -198,7 +200,6 @@ require (
github.com/google/renameio/v2 v2.0.0 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect
github.com/gosuri/uitable v0.0.4 // indirect
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
@ -230,6 +231,21 @@ require (
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect
github.com/pelletier/go-toml v1.9.5 // indirect
github.com/peterbourgon/diskv v2.0.1+incompatible // indirect
github.com/pion/datachannel v1.5.9 // indirect
github.com/pion/dtls/v3 v3.0.3 // indirect
github.com/pion/ice/v4 v4.0.2 // indirect
github.com/pion/interceptor v0.1.37 // indirect
github.com/pion/logging v0.2.2 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/rtcp v1.2.14 // indirect
github.com/pion/rtp v1.8.9 // indirect
github.com/pion/sctp v1.8.33 // indirect
github.com/pion/sdp/v3 v3.0.9 // indirect
github.com/pion/srtp/v3 v3.0.4 // indirect
github.com/pion/stun/v3 v3.0.0 // indirect
github.com/pion/transport/v3 v3.0.7 // indirect
github.com/pion/turn/v4 v4.0.0 // indirect
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
github.com/puzpuzpuz/xsync v1.5.2 // indirect
github.com/rtr7/dhcp4 v0.0.0-20220302171438-18c84d089b46 // indirect
@ -239,6 +255,7 @@ require (
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 // indirect
github.com/stacklok/frizbee v0.1.7 // indirect
github.com/vishvananda/netlink v1.3.1-0.20240922070040-084abd93d350 // indirect
github.com/wlynxg/anet v0.0.3 // indirect
github.com/xen0n/gosmopolitan v1.2.2 // indirect
github.com/xlab/treeprint v1.2.0 // indirect
github.com/ykadowak/zerologlint v0.1.5 // indirect

34
go.sum
View File

@ -948,6 +948,38 @@ github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1H
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE=
github.com/pierrec/lz4/v4 v4.1.25 h1:kocOqRffaIbU5djlIBr7Wh+cx82C0vtFb0fOurZHqD0=
github.com/pierrec/lz4/v4 v4.1.25/go.mod h1:EoQMVJgeeEOMsCqCzqFm2O0cJvljX2nGZjcRIPL34O4=
github.com/pion/datachannel v1.5.9 h1:LpIWAOYPyDrXtU+BW7X0Yt/vGtYxtXQ8ql7dFfYUVZA=
github.com/pion/datachannel v1.5.9/go.mod h1:kDUuk4CU4Uxp82NH4LQZbISULkX/HtzKa4P7ldf9izE=
github.com/pion/dtls/v3 v3.0.3 h1:j5ajZbQwff7Z8k3pE3S+rQ4STvKvXUdKsi/07ka+OWM=
github.com/pion/dtls/v3 v3.0.3/go.mod h1:weOTUyIV4z0bQaVzKe8kpaP17+us3yAuiQsEAG1STMU=
github.com/pion/ice/v4 v4.0.2 h1:1JhBRX8iQLi0+TfcavTjPjI6GO41MFn4CeTBX+Y9h5s=
github.com/pion/ice/v4 v4.0.2/go.mod h1:DCdqyzgtsDNYN6/3U8044j3U7qsJ9KFJC92VnOWHvXg=
github.com/pion/interceptor v0.1.37 h1:aRA8Zpab/wE7/c0O3fh1PqY0AJI3fCSEM5lRWJVorwI=
github.com/pion/interceptor v0.1.37/go.mod h1:JzxbJ4umVTlZAf+/utHzNesY8tmRkM2lVmkS82TTj8Y=
github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM=
github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA=
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/rtcp v1.2.14 h1:KCkGV3vJ+4DAJmvP0vaQShsb0xkRfWkO540Gy102KyE=
github.com/pion/rtcp v1.2.14/go.mod h1:sn6qjxvnwyAkkPzPULIbVqSKI5Dv54Rv7VG0kNxh9L4=
github.com/pion/rtp v1.8.9 h1:E2HX740TZKaqdcPmf4pw6ZZuG8u5RlMMt+l3dxeu6Wk=
github.com/pion/rtp v1.8.9/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU=
github.com/pion/sctp v1.8.33 h1:dSE4wX6uTJBcNm8+YlMg7lw1wqyKHggsP5uKbdj+NZw=
github.com/pion/sctp v1.8.33/go.mod h1:beTnqSzewI53KWoG3nqB282oDMGrhNxBdb+JZnkCwRM=
github.com/pion/sdp/v3 v3.0.9 h1:pX++dCHoHUwq43kuwf3PyJfHlwIj4hXA7Vrifiq0IJY=
github.com/pion/sdp/v3 v3.0.9/go.mod h1:B5xmvENq5IXJimIO4zfp6LAe1fD9N+kFv+V/1lOdz8M=
github.com/pion/srtp/v3 v3.0.4 h1:2Z6vDVxzrX3UHEgrUyIGM4rRouoC7v+NiF1IHtp9B5M=
github.com/pion/srtp/v3 v3.0.4/go.mod h1:1Jx3FwDoxpRaTh1oRV8A/6G1BnFL+QI82eK4ms8EEJQ=
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
github.com/pion/turn/v4 v4.0.0 h1:qxplo3Rxa9Yg1xXDxxH8xaqcyGUtbHYw4QSCvmFWvhM=
github.com/pion/turn/v4 v4.0.0/go.mod h1:MuPDkm15nYSklKpN8vWJ9W2M0PlyQZqYt1McGuxG7mA=
github.com/pion/webrtc/v4 v4.0.0 h1:x8ec7uJQPP3D1iI8ojPAiTOylPI7Fa7QgqZrhpLyqZ8=
github.com/pion/webrtc/v4 v4.0.0/go.mod h1:SfNn8CcFxR6OUVjLXVslAQ3a3994JhyE3Hw1jAuqEto=
github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0=
github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4=
@ -1209,6 +1241,8 @@ github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=

View File

@ -651,7 +651,18 @@ table tbody tr:nth-child(even) td { background-color: #f5f5f5; }
if ps.Relay != "" && ps.CurAddr == "" {
f("relay <b>%s</b>", html.EscapeString(ps.Relay))
} else if ps.CurAddr != "" {
f("direct <b>%s</b>", html.EscapeString(ps.CurAddr))
// Check if this is a WebRTC connection (magic IP 127.3.3.41)
if strings.HasPrefix(ps.CurAddr, "127.3.3.41:") {
// Extract the actual remote address if present
if idx := strings.Index(ps.CurAddr, " ("); idx > 0 {
remoteAddr := ps.CurAddr[idx+2 : len(ps.CurAddr)-1] // Extract address from " (addr)"
f("webrtc <b>%s</b>", html.EscapeString(remoteAddr))
} else {
f("webrtc")
}
} else {
f("direct <b>%s</b>", html.EscapeString(ps.CurAddr))
}
}
}

53
rtclib/signaling.go Normal file
View File

@ -0,0 +1,53 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package rtclib
import "github.com/pion/webrtc/v4"
// Signaling message types. These are the values carried in
// SignalingMessage.Type.
const (
	MessageTypeOffer     = "offer"     // SDP offer from the initiating peer
	MessageTypeAnswer    = "answer"    // SDP answer from the responding peer
	MessageTypeCandidate = "candidate" // trickled ICE candidate
)
// SignalingMessage represents a message exchanged over the signaling channel.
// Typically only the payload field matching Type (Offer, Answer, or
// Candidate) is set; the others are omitted from the JSON encoding.
type SignalingMessage struct {
	Type string `json:"type"` // one of MessageTypeOffer/Answer/Candidate
	From string `json:"from"` // sender's disco public key (hex)
	To   string `json:"to"`   // recipient's disco public key (hex)

	Offer     *webrtc.SessionDescription `json:"offer,omitempty"`
	Answer    *webrtc.SessionDescription `json:"answer,omitempty"`
	Candidate *webrtc.ICECandidateInit   `json:"candidate,omitempty"`
}
// SignalHandler defines callbacks for handling incoming signaling messages.
// A Signaller invokes these as messages arrive; in each callback, from and
// to are the sender's and recipient's disco public keys (hex).
type SignalHandler interface {
	// HandleOffer is called when an offer is received from a peer.
	HandleOffer(from, to string, offer *webrtc.SessionDescription)
	// HandleAnswer is called when an answer is received from a peer.
	HandleAnswer(from, to string, answer *webrtc.SessionDescription)
	// HandleCandidate is called when an ICE candidate is received from a peer.
	HandleCandidate(from, to string, candidate *webrtc.ICECandidateInit)
}
// Signaller defines the interface for WebRTC signaling implementations.
// The from/to arguments on the send methods are disco public keys (hex),
// matching the From/To fields of SignalingMessage.
type Signaller interface {
	// Start begins the signaling connection with the provided handler.
	// The handler receives all subsequent inbound messages.
	Start(handler SignalHandler) error
	// Offer sends an SDP offer to a peer.
	Offer(from, to string, offer *webrtc.SessionDescription) error
	// Answer sends an SDP answer to a peer.
	Answer(from, to string, answer *webrtc.SessionDescription) error
	// Candidate sends an ICE candidate to a peer.
	Candidate(from, to string, candidate *webrtc.ICECandidateInit) error
	// Close shuts down the signaling connection.
	Close() error
}

View File

@ -3292,6 +3292,14 @@ const DerpMagicIP = "127.3.3.40"
var DerpMagicIPAddr = netip.MustParseAddr(DerpMagicIP)
// WebRTCMagicIP is a fake WireGuard endpoint IP address that means
// to use WebRTC data channel for packet transmission.
//
// Mnemonic: 127.3.3.41 is one above DerpMagicIP for WebRTC.
const WebRTCMagicIP = "127.3.3.41"
var WebRTCMagicIPAddr = netip.MustParseAddr(WebRTCMagicIP)
// EarlyNoise is the early payload that's sent over Noise but before the HTTP/2
// handshake when connecting to the coordination server.
//

View File

@ -66,6 +66,13 @@ var (
// suppressing/dropping inbound/outbound [disco.Ping] messages, forcing
// all peer communication over DERP or peer relay.
debugNeverDirectUDP = envknob.RegisterBool("TS_DEBUG_NEVER_DIRECT_UDP")
// debugWebRTCSignalingURL sets the WebRTC signaling server URL for
// establishing WebRTC peer connections. When set, magicsock will attempt
// to use WebRTC as an additional path for peer communication.
debugWebRTCSignalingURL = envknob.RegisterString("TS_DEBUG_WEBRTC_SIGNALING_URL")
// debugAlwaysWebRTC forces all peer communication over WebRTC by
// suppressing disco pings to direct UDP and DERP addresses.
debugAlwaysWebRTC = envknob.RegisterBool("TS_DEBUG_ALWAYS_USE_WEBRTC")
// Hey you! Adding a new debugknob? Make sure to stub it out in the
// debugknobs_stubs.go file too.
)

View File

@ -7,6 +7,7 @@ package magicsock
import (
"net/netip"
"os"
"tailscale.com/types/opt"
)
@ -32,3 +33,5 @@ func inTest() bool { return false }
func debugPeerMap() bool { return false }
func pretendpoints() []netip.AddrPort { return []netip.AddrPort{} }
func debugNeverDirectUDP() bool { return false }
func debugWebRTCSignalingURL() string { return os.Getenv("TS_DEBUG_WEBRTC_SIGNALING_URL") }
func debugAlwaysWebRTC() bool { return false }

View File

@ -1076,7 +1076,13 @@ func (de *endpoint) send(buffs [][]byte, offset int) error {
}
}
var err error
if udpAddr.ap.IsValid() {
// Check if this is a WebRTC address and route accordingly
if udpAddr.ap.IsValid() && udpAddr.ap.Addr() == tailcfg.WebRTCMagicIPAddr {
// Pack all buffs into one SCTP message. See sendWebRTCBatch for why.
if err = de.c.sendWebRTCBatch(de.publicKey, buffs, offset); err != nil {
return err
}
} else if udpAddr.ap.IsValid() {
_, err = de.c.sendUDPBatch(udpAddr, buffs, offset)
// If the error is known to indicate that the endpoint is no longer
@ -1295,6 +1301,9 @@ func (de *endpoint) startDiscoPingLocked(ep epAddr, now mono.Time, purpose disco
if debugNeverDirectUDP() && !ep.vni.IsSet() && ep.ap.Addr() != tailcfg.DerpMagicIPAddr {
return
}
if debugAlwaysWebRTC() && !ep.vni.IsSet() && ep.ap.Addr() != tailcfg.WebRTCMagicIPAddr {
return
}
epDisco := de.disco.Load()
if epDisco == nil {
return
@ -1815,10 +1824,13 @@ type epAddr struct {
vni packet.VirtualNetworkID // vni.IsSet() indicates if this [epAddr] involves a Geneve header
}
// isDirect returns true if e.ap is valid and not tailcfg.DerpMagicIPAddr,
// isDirect returns true if e.ap is valid and not tailcfg.DerpMagicIPAddr or WebRTCMagicIPAddr,
// and a VNI is not set.
func (e epAddr) isDirect() bool {
return e.ap.IsValid() && e.ap.Addr() != tailcfg.DerpMagicIPAddr && !e.vni.IsSet()
return e.ap.IsValid() &&
e.ap.Addr() != tailcfg.DerpMagicIPAddr &&
e.ap.Addr() != tailcfg.WebRTCMagicIPAddr &&
!e.vni.IsSet()
}
func (e epAddr) String() string {
@ -1871,6 +1883,28 @@ func betterAddr(a, b addrQuality) bool {
return false
}
// WebRTC path priority: Direct UDP > WebRTC > Peer Relay/DERP
aIsWebRTC := a.ap.Addr() == tailcfg.WebRTCMagicIPAddr
bIsWebRTC := b.ap.Addr() == tailcfg.WebRTCMagicIPAddr
aIsDERP := a.ap.Addr() == tailcfg.DerpMagicIPAddr
bIsDERP := b.ap.Addr() == tailcfg.DerpMagicIPAddr
// Direct paths beat WebRTC
if a.isDirect() && bIsWebRTC {
return true
}
if b.isDirect() && aIsWebRTC {
return false
}
// WebRTC beats DERP and relay (VNI)
if aIsWebRTC && (bIsDERP || b.vni.IsSet()) {
return true
}
if bIsWebRTC && (aIsDERP || a.vni.IsSet()) {
return false
}
// Each address starts with a set of points (from 0 to 100) that
// represents how much faster they are than the highest-latency
// endpoint. For example, if a has latency 200ms and b has latency
@ -1891,19 +1925,26 @@ func betterAddr(a, b addrQuality) bool {
// addresses, and prefer link-local unicast addresses over other types
// of private IP addresses since it's definitionally more likely that
// they'll be on the same network segment than a general private IP.
if a.ap.Addr().IsLoopback() {
aPoints += 50
} else if a.ap.Addr().IsLinkLocalUnicast() {
aPoints += 30
} else if a.ap.Addr().IsPrivate() {
aPoints += 20
//
// Exclude magic IPs (DERP, WebRTC) from these bonuses as they're not
// real network paths.
if !aIsDERP && !aIsWebRTC {
if a.ap.Addr().IsLoopback() {
aPoints += 50
} else if a.ap.Addr().IsLinkLocalUnicast() {
aPoints += 30
} else if a.ap.Addr().IsPrivate() {
aPoints += 20
}
}
if b.ap.Addr().IsLoopback() {
bPoints += 50
} else if b.ap.Addr().IsLinkLocalUnicast() {
bPoints += 30
} else if b.ap.Addr().IsPrivate() {
bPoints += 20
if !bIsDERP && !bIsWebRTC {
if b.ap.Addr().IsLoopback() {
bPoints += 50
} else if b.ap.Addr().IsLinkLocalUnicast() {
bPoints += 30
} else if b.ap.Addr().IsPrivate() {
bPoints += 20
}
}
// Prefer IPv6 for being a bit more robust, as long as
@ -2035,6 +2076,14 @@ func (de *endpoint) populatePeerStatus(ps *ipnstate.PeerStatus) {
ps.PeerRelay = udpAddr.String()
} else {
ps.CurAddr = udpAddr.String()
// If this is a WebRTC connection, append the actual remote address
if udpAddr.ap.Addr() == tailcfg.WebRTCMagicIPAddr && de.c.webrtcMgr != nil {
if disco := de.disco.Load(); disco != nil {
if remoteAddr := de.c.webrtcMgr.getRemoteAddr(disco.key); remoteAddr.IsValid() {
ps.CurAddr = fmt.Sprintf("%s (%s)", ps.CurAddr, remoteAddr)
}
}
}
}
}
}

View File

@ -94,6 +94,7 @@ type Path string
const (
PathDirectIPv4 Path = "direct_ipv4"
PathDirectIPv6 Path = "direct_ipv6"
PathWebRTC Path = "webrtc"
PathDERP Path = "derp"
PathPeerRelayIPv4 Path = "peer_relay_ipv4"
PathPeerRelayIPv6 Path = "peer_relay_ipv6"
@ -109,6 +110,14 @@ type pathLabel struct {
Path Path
}
// webrtcReadResult is the result of reading a packet from a WebRTC data channel.
// It is similar to derpReadResult but for WebRTC connections.
type webrtcReadResult struct {
n int // length of data in buf
src key.NodePublic // sender's node public key
buf []byte // packet data; nil signals the receiver to ignore this message
}
// metrics in wgengine contains the usermetrics counters for magicsock; it
// is, however, a bit special. All of the metrics are labeled, but looking up
// the metric every time we need to record it has an overhead, and includes
@ -119,6 +128,7 @@ type metrics struct {
// labeled by the path the packet took.
inboundPacketsIPv4Total expvar.Int
inboundPacketsIPv6Total expvar.Int
inboundPacketsWebRTCTotal expvar.Int
inboundPacketsDERPTotal expvar.Int
inboundPacketsPeerRelayIPv4Total expvar.Int
inboundPacketsPeerRelayIPv6Total expvar.Int
@ -127,6 +137,7 @@ type metrics struct {
// labeled by the path the packet took.
inboundBytesIPv4Total expvar.Int
inboundBytesIPv6Total expvar.Int
inboundBytesWebRTCTotal expvar.Int
inboundBytesDERPTotal expvar.Int
inboundBytesPeerRelayIPv4Total expvar.Int
inboundBytesPeerRelayIPv6Total expvar.Int
@ -135,6 +146,7 @@ type metrics struct {
// labeled by the path the packet took.
outboundPacketsIPv4Total expvar.Int
outboundPacketsIPv6Total expvar.Int
outboundPacketsWebRTCTotal expvar.Int
outboundPacketsDERPTotal expvar.Int
outboundPacketsPeerRelayIPv4Total expvar.Int
outboundPacketsPeerRelayIPv6Total expvar.Int
@ -143,6 +155,7 @@ type metrics struct {
// labeled by the path the packet took.
outboundBytesIPv4Total expvar.Int
outboundBytesIPv6Total expvar.Int
outboundBytesWebRTCTotal expvar.Int
outboundBytesDERPTotal expvar.Int
outboundBytesPeerRelayIPv4Total expvar.Int
outboundBytesPeerRelayIPv6Total expvar.Int
@ -211,6 +224,10 @@ type Conn struct {
// It must have buffer size > 0; see issue 3736.
derpRecvCh chan derpReadResult
// webrtcRecvCh is used by receiveWebRTC to read WebRTC messages.
// It must have buffer size > 0, similar to derpRecvCh.
webrtcRecvCh chan webrtcReadResult
// bind is the wireguard-go conn.Bind for Conn.
bind *connBind
@ -343,6 +360,10 @@ type Conn struct {
// [tailscale.com/net/udprelay.Server] endpoints.
relayManager relayManager
// webrtcMgr manages WebRTC connections for peers.
// May be nil if WebRTC is disabled (no TS_DEBUG_WEBRTC_SIGNALING_URL).
webrtcMgr *webrtcManager
// discoInfo is the state for an active peer DiscoKey.
discoInfo map[key.DiscoPublic]*discoInfo
@ -575,7 +596,8 @@ func newConn(logf logger.Logf) *Conn {
discoPrivate := key.NewDisco()
c := &Conn{
logf: logf,
derpRecvCh: make(chan derpReadResult, 1), // must be buffered, see issue 3736
derpRecvCh: make(chan derpReadResult, 1), // must be buffered, see issue 3736
webrtcRecvCh: make(chan webrtcReadResult, 64), // must be buffered, similar to derpRecvCh
derpStarted: make(chan struct{}),
peerLastDerp: make(map[key.NodePublic]int),
peerMap: newPeerMap(),
@ -727,6 +749,16 @@ func NewConn(opts Options) (*Conn, error) {
}
c.logf("magicsock: disco key = %v", c.discoAtomic.Short())
// Initialize WebRTC manager if signaling server URL is set
if signalingURL := debugWebRTCSignalingURL(); signalingURL != "" {
c.logf("magicsock: initializing WebRTC with signaling server %s", signalingURL)
c.webrtcMgr = newWebRTCManager(c, signalingURL)
if c.webrtcMgr == nil {
c.logf("magicsock: failed to initialize WebRTC manager")
}
}
return c, nil
}
@ -736,6 +768,7 @@ func NewConn(opts Options) (*Conn, error) {
func registerMetrics(reg *usermetric.Registry) *metrics {
pathDirectV4 := pathLabel{Path: PathDirectIPv4}
pathDirectV6 := pathLabel{Path: PathDirectIPv6}
pathWebRTC := pathLabel{Path: PathWebRTC}
pathDERP := pathLabel{Path: PathDERP}
pathPeerRelayV4 := pathLabel{Path: PathPeerRelayIPv4}
pathPeerRelayV6 := pathLabel{Path: PathPeerRelayIPv6}
@ -770,21 +803,25 @@ func registerMetrics(reg *usermetric.Registry) *metrics {
// Map clientmetrics to the usermetric counters.
metricRecvDataPacketsIPv4.Register(&m.inboundPacketsIPv4Total)
metricRecvDataPacketsIPv6.Register(&m.inboundPacketsIPv6Total)
metricRecvDataPacketsWebRTC.Register(&m.inboundPacketsWebRTCTotal)
metricRecvDataPacketsDERP.Register(&m.inboundPacketsDERPTotal)
metricRecvDataPacketsPeerRelayIPv4.Register(&m.inboundPacketsPeerRelayIPv4Total)
metricRecvDataPacketsPeerRelayIPv6.Register(&m.inboundPacketsPeerRelayIPv6Total)
metricRecvDataBytesIPv4.Register(&m.inboundBytesIPv4Total)
metricRecvDataBytesIPv6.Register(&m.inboundBytesIPv6Total)
metricRecvDataBytesWebRTC.Register(&m.inboundBytesWebRTCTotal)
metricRecvDataBytesDERP.Register(&m.inboundBytesDERPTotal)
metricRecvDataBytesPeerRelayIPv4.Register(&m.inboundBytesPeerRelayIPv4Total)
metricRecvDataBytesPeerRelayIPv6.Register(&m.inboundBytesPeerRelayIPv6Total)
metricSendDataPacketsIPv4.Register(&m.outboundPacketsIPv4Total)
metricSendDataPacketsIPv6.Register(&m.outboundPacketsIPv6Total)
metricSendDataPacketsWebRTC.Register(&m.outboundPacketsWebRTCTotal)
metricSendDataPacketsDERP.Register(&m.outboundPacketsDERPTotal)
metricSendDataPacketsPeerRelayIPv4.Register(&m.outboundPacketsPeerRelayIPv4Total)
metricSendDataPacketsPeerRelayIPv6.Register(&m.outboundPacketsPeerRelayIPv6Total)
metricSendDataBytesIPv4.Register(&m.outboundBytesIPv4Total)
metricSendDataBytesIPv6.Register(&m.outboundBytesIPv6Total)
metricSendDataBytesWebRTC.Register(&m.outboundBytesWebRTCTotal)
metricSendDataBytesDERP.Register(&m.outboundBytesDERPTotal)
metricSendDataBytesPeerRelayIPv4.Register(&m.outboundBytesPeerRelayIPv4Total)
metricSendDataBytesPeerRelayIPv6.Register(&m.outboundBytesPeerRelayIPv6Total)
@ -796,24 +833,28 @@ func registerMetrics(reg *usermetric.Registry) *metrics {
inboundPacketsTotal.Set(pathDirectV4, &m.inboundPacketsIPv4Total)
inboundPacketsTotal.Set(pathDirectV6, &m.inboundPacketsIPv6Total)
inboundPacketsTotal.Set(pathWebRTC, &m.inboundPacketsWebRTCTotal)
inboundPacketsTotal.Set(pathDERP, &m.inboundPacketsDERPTotal)
inboundPacketsTotal.Set(pathPeerRelayV4, &m.inboundPacketsPeerRelayIPv4Total)
inboundPacketsTotal.Set(pathPeerRelayV6, &m.inboundPacketsPeerRelayIPv6Total)
inboundBytesTotal.Set(pathDirectV4, &m.inboundBytesIPv4Total)
inboundBytesTotal.Set(pathDirectV6, &m.inboundBytesIPv6Total)
inboundBytesTotal.Set(pathWebRTC, &m.inboundBytesWebRTCTotal)
inboundBytesTotal.Set(pathDERP, &m.inboundBytesDERPTotal)
inboundBytesTotal.Set(pathPeerRelayV4, &m.inboundBytesPeerRelayIPv4Total)
inboundBytesTotal.Set(pathPeerRelayV6, &m.inboundBytesPeerRelayIPv6Total)
outboundPacketsTotal.Set(pathDirectV4, &m.outboundPacketsIPv4Total)
outboundPacketsTotal.Set(pathDirectV6, &m.outboundPacketsIPv6Total)
outboundPacketsTotal.Set(pathWebRTC, &m.outboundPacketsWebRTCTotal)
outboundPacketsTotal.Set(pathDERP, &m.outboundPacketsDERPTotal)
outboundPacketsTotal.Set(pathPeerRelayV4, &m.outboundPacketsPeerRelayIPv4Total)
outboundPacketsTotal.Set(pathPeerRelayV6, &m.outboundPacketsPeerRelayIPv6Total)
outboundBytesTotal.Set(pathDirectV4, &m.outboundBytesIPv4Total)
outboundBytesTotal.Set(pathDirectV6, &m.outboundBytesIPv6Total)
outboundBytesTotal.Set(pathWebRTC, &m.outboundBytesWebRTCTotal)
outboundBytesTotal.Set(pathDERP, &m.outboundBytesDERPTotal)
outboundBytesTotal.Set(pathPeerRelayV4, &m.outboundBytesPeerRelayIPv4Total)
outboundBytesTotal.Set(pathPeerRelayV6, &m.outboundBytesPeerRelayIPv6Total)
@ -828,21 +869,25 @@ func registerMetrics(reg *usermetric.Registry) *metrics {
func deregisterMetrics() {
metricRecvDataPacketsIPv4.UnregisterAll()
metricRecvDataPacketsIPv6.UnregisterAll()
metricRecvDataPacketsWebRTC.UnregisterAll()
metricRecvDataPacketsDERP.UnregisterAll()
metricRecvDataPacketsPeerRelayIPv4.UnregisterAll()
metricRecvDataPacketsPeerRelayIPv6.UnregisterAll()
metricRecvDataBytesIPv4.UnregisterAll()
metricRecvDataBytesIPv6.UnregisterAll()
metricRecvDataBytesWebRTC.UnregisterAll()
metricRecvDataBytesDERP.UnregisterAll()
metricRecvDataBytesPeerRelayIPv4.UnregisterAll()
metricRecvDataBytesPeerRelayIPv6.UnregisterAll()
metricSendDataPacketsIPv4.UnregisterAll()
metricSendDataPacketsIPv6.UnregisterAll()
metricSendDataPacketsWebRTC.UnregisterAll()
metricSendDataPacketsDERP.UnregisterAll()
metricSendDataPacketsPeerRelayIPv4.UnregisterAll()
metricSendDataPacketsPeerRelayIPv6.UnregisterAll()
metricSendDataBytesIPv4.UnregisterAll()
metricSendDataBytesIPv6.UnregisterAll()
metricSendDataBytesWebRTC.UnregisterAll()
metricSendDataBytesDERP.UnregisterAll()
metricSendDataBytesPeerRelayIPv4.UnregisterAll()
metricSendDataBytesPeerRelayIPv6.UnregisterAll()
@ -1630,6 +1675,9 @@ func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error)
// IPv6 address when the local machine doesn't have IPv6 support
// returns (false, nil); it's not an error, but nothing was sent.
func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte, isDisco bool, isGeneveEncap bool) (sent bool, err error) {
if addr.Addr() == tailcfg.WebRTCMagicIPAddr {
return c.sendWebRTC(addr, pubKey, b)
}
if addr.Addr() != tailcfg.DerpMagicIPAddr {
return c.sendUDP(addr, b, isDisco, isGeneveEncap)
}
@ -1671,6 +1719,170 @@ func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte, is
return false, errDropDerpPacket
}
// webrtcBatchMagic is the first byte of a batched WebRTC SCTP message.
// WireGuard packets start with bytes 0x01-0x04 (the message-type field),
// and disco packets start with 0x54 ('T'), so 0xBA is unambiguous.
const webrtcBatchMagic = byte(0xBA)
// sendWebRTCBatch packs all buffs into a single SCTP message and sends it.
// Batching is the critical throughput optimization: the per-packet WebRTC
// path calls rwc.Write once per WireGuard packet, producing one SCTP message,
// one DTLS record, and one UDP send per packet. Packing N packets into one
// SCTP message reduces that to a single write — the same advantage
// sendUDPBatch (sendmmsg) gives the regular UDP path.
//
// Wire format for N>1: [0xBA magic][2-byte BE len][packet]...[2-byte BE len][packet]
// Single packet: sent as-is with no framing overhead.
//
// Packets are silently dropped (nil error) when WebRTC is disabled or the
// peer's endpoint/disco key cannot be resolved, matching sendWebRTC.
func (c *Conn) sendWebRTCBatch(pubKey key.NodePublic, buffs [][]byte, offset int) error {
	if c.webrtcMgr == nil {
		return nil
	}
	// Resolve endpoint and disco key once for the whole batch.
	c.mu.Lock()
	ep, ok := c.peerMap.endpointForNodeKey(pubKey)
	c.mu.Unlock()
	if !ok || ep == nil {
		return nil
	}
	disco := ep.disco.Load()
	if disco == nil {
		return nil
	}
	if len(buffs) == 1 {
		// Fast path: single packet, no framing overhead.
		b := buffs[0][offset:]
		if err := c.webrtcMgr.sendPacket(disco.key, b); err != nil {
			return err
		}
		c.metrics.outboundPacketsWebRTCTotal.Add(1)
		c.metrics.outboundBytesWebRTCTotal.Add(int64(len(b)))
		return nil
	}
	// Multi-packet batch path.
	size := 1 // magic byte
	for _, b := range buffs {
		pkt := b[offset:]
		// The per-packet length is framed as a uint16. A longer packet would
		// silently truncate the length field and corrupt the entire batch on
		// the receiving side, so reject it explicitly.
		if len(pkt) > 0xFFFF {
			return fmt.Errorf("magicsock: webrtc batch packet too large: %d bytes", len(pkt))
		}
		size += 2 + len(pkt)
	}
	batch := make([]byte, size)
	batch[0] = webrtcBatchMagic
	pos := 1
	var totalBytes int64
	for _, b := range buffs {
		pkt := b[offset:]
		binary.BigEndian.PutUint16(batch[pos:], uint16(len(pkt)))
		pos += 2
		copy(batch[pos:], pkt)
		pos += len(pkt)
		totalBytes += int64(len(pkt))
	}
	if err := c.webrtcMgr.sendPacket(disco.key, batch); err != nil {
		return err
	}
	c.metrics.outboundPacketsWebRTCTotal.Add(int64(len(buffs)))
	c.metrics.outboundBytesWebRTCTotal.Add(totalBytes)
	return nil
}
// sendWebRTC sends a single packet to pubKey over the peer's WebRTC data
// channel. It reports sent=false with a nil error (a silent drop) when
// WebRTC is disabled or the peer's endpoint/disco key cannot be resolved.
func (c *Conn) sendWebRTC(addr netip.AddrPort, pubKey key.NodePublic, b []byte) (sent bool, err error) {
	if c.webrtcMgr == nil {
		return false, nil
	}
	// Find the endpoint by public key.
	c.mu.Lock()
	ep, ok := c.peerMap.endpointForNodeKey(pubKey)
	c.mu.Unlock()
	// Guard ep == nil as well, for consistency with sendWebRTCBatch.
	if !ok || ep == nil {
		return false, nil
	}
	// Get the disco key for this endpoint; WebRTC connections are keyed by it.
	disco := ep.disco.Load()
	if disco == nil {
		return false, nil
	}
	// Send via WebRTC manager.
	if err := c.webrtcMgr.sendPacket(disco.key, b); err != nil {
		return false, err
	}
	// Update metrics.
	c.metrics.outboundPacketsWebRTCTotal.Add(1)
	c.metrics.outboundBytesWebRTCTotal.Add(int64(len(b)))
	return true, nil
}
// receiveWebRTC is called by webrtcManager when data arrives on a WebRTC
// data channel. It copies the payload and queues it for processing by
// wireguard-go through webrtcRecvCh, dropping the packet (with a log line)
// if the channel is full.
func (c *Conn) receiveWebRTC(b []byte, srcNodeKey key.NodePublic) {
	// b aliases the reader goroutine's reusable buffer, which is overwritten
	// on its next Read call; copy before handing the packet off.
	pkt := append([]byte(nil), b...)
	res := webrtcReadResult{n: len(pkt), src: srcNodeKey, buf: pkt}
	select {
	case c.webrtcRecvCh <- res:
	case <-c.connCtx.Done():
	default:
		c.logf("webrtc: dropped packet from %v, receive channel full", srcNodeKey.ShortString())
	}
}
// processWebRTCReadResult processes a WebRTC packet received from the
// webrtcRecvCh, copying its payload into wireguard-go's buffer b. It returns
// the number of payload bytes written and the originating endpoint, or
// (0, nil) when the packet was consumed internally (disco) or dropped.
// It's similar to processDERPReadResult but for WebRTC packets.
func (c *Conn) processWebRTCReadResult(wr webrtcReadResult, b []byte) (n int, ep *endpoint) {
	if wr.buf == nil {
		// Zero-value result (e.g. channel closed); nothing to deliver.
		return 0, nil
	}
	n = wr.n
	ncopy := copy(b, wr.buf[:n])
	if ncopy != n {
		// Packet is larger than wireguard-go's receive buffer; drop it.
		err := fmt.Errorf("received WebRTC packet of length %d that's too big for WireGuard buf size %d", n, ncopy)
		c.logf("magicsock: %v", err)
		return 0, nil
	}
	// Synthesize a source address using the WebRTC magic IP so downstream
	// code can recognize the path; 12345 is a dummy port.
	srcAddr := epAddr{ap: netip.AddrPortFrom(tailcfg.WebRTCMagicIPAddr, 12345)}
	// Disco packets are handled by magicsock itself, not passed to WireGuard.
	pt, isGeneveEncap := packetLooksLike(b[:n])
	if pt == packetLooksLikeDisco && !isGeneveEncap {
		c.handleDiscoMessage(b[:n], srcAddr, false, wr.src, discoRXPathWebRTC)
		return 0, nil
	}
	// Find the endpoint by node key.
	var ok bool
	c.mu.Lock()
	ep, ok = c.peerMap.endpointForNodeKey(wr.src)
	c.mu.Unlock()
	if !ok {
		// We don't know anything about this node key; drop the packet.
		return 0, nil
	}
	ep.noteRecvActivity(srcAddr, mono.Now())
	// Report the flow to the connection counter callback, if one is set.
	if update := c.connCounter.Load(); update != nil {
		update(0, netip.AddrPortFrom(ep.nodeAddr, 0), srcAddr.ap, 1, n, true)
	}
	c.metrics.inboundPacketsWebRTCTotal.Add(1)
	c.metrics.inboundBytesWebRTCTotal.Add(int64(n))
	return n, ep
}
// receiveBatch holds a slice of ipv6.Message values used for batched
// packet receive. NOTE(review): only the message slice is visible here;
// confirm buffer-reuse semantics against the receive path's callers.
type receiveBatch struct {
	msgs []ipv6.Message
}
@ -2005,7 +2217,7 @@ func (c *Conn) sendDiscoMessage(dst epAddr, dstKey key.NodePublic, dstDisco key.
if !dstKey.IsZero() {
node = dstKey.ShortString()
}
c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v\n", c.discoAtomic.Short(), dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m), len(pkt))
c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v\n", c.discoAtomic.Short(), dstDisco.ShortString(), node, pathStr(dst.String()), disco.MessageSummary(m), len(pkt))
}
if isDERP {
metricSentDiscoDERP.Add(1)
@ -2045,6 +2257,7 @@ type discoRXPath string
// Values for discoRXPath, naming the transport a disco packet arrived on
// (used in diagnostics/logging).
const (
	discoRXPathUDP       discoRXPath = "UDP socket"
	discoRXPathDERP      discoRXPath = "DERP"
	discoRXPathWebRTC    discoRXPath = "WebRTC" // received via a WebRTC data channel
	discoRXPathRawSocket discoRXPath = "raw socket"
)
@ -2356,13 +2569,13 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake
if isVia {
c.dlogf("[v1] magicsock: disco: %v<-%v via %v (%v, %v) got call-me-maybe-via, %d endpoints",
c.discoAtomic.Short(), epDisco.short, via.ServerDisco.ShortString(),
ep.publicKey.ShortString(), derpStr(src.String()),
ep.publicKey.ShortString(), pathStr(src.String()),
len(via.AddrPorts))
c.relayManager.handleCallMeMaybeVia(ep, lastBest, lastBestIsTrusted, via)
} else {
c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints",
c.discoAtomic.Short(), epDisco.short,
ep.publicKey.ShortString(), derpStr(src.String()),
ep.publicKey.ShortString(), pathStr(src.String()),
len(cmm.MyNumber))
go ep.handleCallMeMaybe(cmm)
}
@ -2408,7 +2621,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake
if isResp {
c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s, %d endpoints",
c.discoAtomic.Short(), epDisco.short,
ep.publicKey.ShortString(), derpStr(src.String()),
ep.publicKey.ShortString(), pathStr(src.String()),
msgType,
len(resp.AddrPorts))
c.relayManager.handleRxDiscoMsg(c, resp, nodeKey, di.discoKey, src)
@ -2422,7 +2635,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake
} else {
c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s disco[0]=%v disco[1]=%v",
c.discoAtomic.Short(), epDisco.short,
ep.publicKey.ShortString(), derpStr(src.String()),
ep.publicKey.ShortString(), pathStr(src.String()),
msgType,
req.ClientDisco[0].ShortString(), req.ClientDisco[1].ShortString())
}
@ -3107,6 +3320,11 @@ func (c *Conn) updateNodes(self tailcfg.NodeView, peers []tailcfg.NodeView) (pee
}
ep.updateFromNode(n, flags.heartbeatDisabled, flags.probeUDPLifetimeOn)
c.peerMap.upsertEndpoint(ep, oldDiscoKey) // maybe update discokey mappings in peerMap
// Start WebRTC connection if not already started
if c.webrtcMgr != nil && !n.DiscoKey().IsZero() {
c.webrtcMgr.startConnection(ep)
}
continue
}
@ -3170,6 +3388,11 @@ func (c *Conn) updateNodes(self tailcfg.NodeView, peers []tailcfg.NodeView) (pee
ep.updateFromNode(n, flags.heartbeatDisabled, flags.probeUDPLifetimeOn)
c.peerMap.upsertEndpoint(ep, key.DiscoPublic{})
// Start WebRTC connection to this peer if WebRTC is enabled
if c.webrtcMgr != nil && !n.DiscoKey().IsZero() {
c.webrtcMgr.startConnection(ep)
}
}
// If the set of nodes changed since the last SetNetworkMap, the
@ -3287,9 +3510,9 @@ func (c *connBind) Open(ignoredPort uint16) ([]conn.ReceiveFunc, uint16, error)
return nil, 0, errors.New("magicsock: connBind already open")
}
c.closed = false
fns := []conn.ReceiveFunc{c.receiveIPv4(), c.receiveIPv6(), c.receiveDERP}
fns := []conn.ReceiveFunc{c.receiveIPv4(), c.receiveIPv6(), c.receiveDERP, c.receiveWebRTC}
if runtime.GOOS == "js" {
fns = []conn.ReceiveFunc{c.receiveDERP}
fns = []conn.ReceiveFunc{c.receiveDERP, c.receiveWebRTC}
}
// TODO: Combine receiveIPv4 and receiveIPv6 and receiveIP into a single
// closure that closes over a *RebindingUDPConn?
@ -3366,6 +3589,14 @@ func (c *Conn) Close() error {
ep.stopAndReset()
})
// Close WebRTC manager if initialized
if c.webrtcMgr != nil {
c.webrtcMgr.close()
c.webrtcMgr = nil
}
close(c.webrtcRecvCh)
c.closed = true
c.connCtxCancel()
c.closeAllDerpLocked("conn-close")
@ -3920,9 +4151,20 @@ func trySetUDPSocketOptions(pconn nettype.PacketConn, logf logger.Logf) {
}
}
// pathStr formats endpoint addresses for display, rewriting the magic
// DERP and WebRTC IPs into the readable "derp-" and "webrtc-" prefixes.
func pathStr(s string) string {
	return webrtcStr(derpStr(s))
}

// derpStr replaces DERP IPs in s with "derp-".
func derpStr(s string) string {
	return strings.ReplaceAll(s, "127.3.3.40:", "derp-")
}

// webrtcStr replaces WebRTC IPs in s with "webrtc-".
func webrtcStr(s string) string {
	return strings.ReplaceAll(s, "127.3.3.41:", "webrtc-")
}
// epAddrEndpointCache is a mutex-free single-element cache, mapping from
// a single [epAddr] to a single [*endpoint].
type epAddrEndpointCache struct {
@ -4003,11 +4245,13 @@ var (
metricRecvDataPacketsDERP = clientmetric.NewAggregateCounter("magicsock_recv_data_derp")
metricRecvDataPacketsIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_ipv4")
metricRecvDataPacketsIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_ipv6")
metricRecvDataPacketsWebRTC = clientmetric.NewAggregateCounter("magicsock_recv_data_webrtc")
metricRecvDataPacketsPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_peer_relay_ipv4")
metricRecvDataPacketsPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_peer_relay_ipv6")
metricSendDataPacketsDERP = clientmetric.NewAggregateCounter("magicsock_send_data_derp")
metricSendDataPacketsIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_ipv4")
metricSendDataPacketsIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_ipv6")
metricSendDataPacketsWebRTC = clientmetric.NewAggregateCounter("magicsock_send_data_webrtc")
metricSendDataPacketsPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_peer_relay_ipv4")
metricSendDataPacketsPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_peer_relay_ipv6")
@ -4015,11 +4259,13 @@ var (
metricRecvDataBytesDERP = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_derp")
metricRecvDataBytesIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_ipv4")
metricRecvDataBytesIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_ipv6")
metricRecvDataBytesWebRTC = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_webrtc")
metricRecvDataBytesPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_peer_relay_ipv4")
metricRecvDataBytesPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_peer_relay_ipv6")
metricSendDataBytesDERP = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_derp")
metricSendDataBytesIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_ipv4")
metricSendDataBytesIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_ipv6")
metricSendDataBytesWebRTC = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_webrtc")
metricSendDataBytesPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_peer_relay_ipv4")
metricSendDataBytesPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_peer_relay_ipv6")
@ -4139,6 +4385,16 @@ func (c *Conn) SetLastNetcheckReportForTest(ctx context.Context, report *netchec
c.lastNetCheckReport.Store(report)
}
// findEndpointByDisco returns the first endpoint with the given disco key,
// or nil if none is known.
func (c *Conn) findEndpointByDisco(dk key.DiscoPublic) *endpoint {
	var match *endpoint
	c.peerMap.forEachEndpointWithDiscoKey(dk, func(ep *endpoint) bool {
		match = ep
		return false // first hit is enough; stop iterating
	})
	return match
}
// lazyEndpoint is a wireguard [conn.Endpoint] for when magicsock received a
// non-disco (presumably WireGuard) packet from a UDP address from which we
// can't map to a Tailscale peer. But WireGuard most likely can, once it

View File

@ -0,0 +1,747 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package magicsock
import (
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/netip"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pion/webrtc/v4"
"github.com/tailscale/wireguard-go/conn"
"tailscale.com/rtclib"
"tailscale.com/tailcfg"
"tailscale.com/tstime/mono"
"tailscale.com/types/key"
"tailscale.com/types/logger"
)
// webrtcConnState represents the state of a WebRTC connection.
type webrtcConnState int

const (
	webrtcStateIdle       webrtcConnState = iota // no connection attempt yet
	webrtcStateConnecting                        // connection setup in progress
	webrtcStateConnected                         // connection established
	webrtcStateFailed                            // connection attempt failed
	webrtcStateClosed                            // connection torn down
)
// dataChannelRW is the detached io.ReadWriteCloser for a WebRTC DataChannel.
// It is stored via atomic.Pointer so the hot send path can retrieve it without
// holding the webrtcManager mutex; wrapping the interface in a struct gives
// atomic.Pointer a concrete element type to point at.
type dataChannelRW struct {
	io.ReadWriteCloser
}
// webrtcPeerState tracks WebRTC connection state for a single peer.
type webrtcPeerState struct {
	ep          *endpoint              // magicsock endpoint for this peer
	peerConn    *webrtc.PeerConnection // underlying pion peer connection
	dataChannel *webrtc.DataChannel    // "tailscale-wg" channel carrying WireGuard traffic

	dcRW atomic.Pointer[dataChannelRW] // non-nil once the DataChannel is open

	localDisco    key.DiscoPublic // our disco key (signaling identity)
	remoteDisco   key.DiscoPublic // peer's disco key
	remoteNodeKey key.NodePublic  // peer's node public key (for WireGuard)
	remoteAddr    netip.AddrPort  // actual remote address from ICE candidate

	// NOTE(review): the fields below appear to be guarded by
	// webrtcManager.mu (all writers seen here hold it) — confirm.
	state     webrtcConnState
	lastError error
	createdAt time.Time
}
// webrtcConnectionReadyEvent signals that a WebRTC connection is ready
// (its data channel has opened). It is sent from pion callbacks to the
// manager's runLoop via connectionReadyCh.
type webrtcConnectionReadyEvent struct {
	remoteDisco key.DiscoPublic // peer's disco key
	ep          *endpoint       // peer's magicsock endpoint
}
// webrtcManager manages WebRTC connections for magicsock.
type webrtcManager struct {
	logf logger.Logf
	conn *Conn // parent magicsock.Conn

	mu                        sync.RWMutex
	peerConnectionsByEndpoint map[*endpoint]*webrtcPeerState       // guarded by mu
	peerConnectionsByDisco    map[key.DiscoPublic]*webrtcPeerState // guarded by mu

	signalingClient *signalingClient

	// Control channels
	startConnectionCh chan *endpoint                  // connection requests consumed by runLoop
	connectionReadyCh chan webrtcConnectionReadyEvent // data-channel-open events consumed by runLoop
	closeCh           chan struct{}                   // closed by close() to stop runLoop
	runLoopStoppedCh  chan struct{}                   // closed by runLoop when it exits

	// WebRTC API configuration
	api *webrtc.API
}

// Ensure webrtcManager implements rtclib.SignalHandler interface.
var _ rtclib.SignalHandler = (*webrtcManager)(nil)
// newWebRTCManager creates a new WebRTC manager and starts its signaling
// client and event loop. It returns nil if the signaling client fails to
// start; callers (NewConn) treat a nil manager as "WebRTC disabled".
func newWebRTCManager(c *Conn, signalingURL string) *webrtcManager {
	mgr := newWebRTCManagerBase(c, signalingURL)
	// Create and start signaling client
	mgr.signalingClient = newSignalingClient(signalingURL, c.logf)
	if err := mgr.signalingClient.Start(mgr); err != nil {
		c.logf("webrtc: failed to start signaling client: %v", err)
		// NOTE(review): the half-built manager (and any resources held by the
		// failed signaling client) is abandoned here — confirm nothing needs
		// explicit cleanup on this path.
		return nil
	}
	// Start event loop
	go mgr.runLoop()
	return mgr
}
// close shuts down the WebRTC manager: it stops the signaling client,
// signals runLoop to exit (waiting up to 2 seconds), then tears down all
// peer connections and drops the state maps. It always returns nil.
//
// NOTE(review): close is not idempotent — a second call would panic on
// close(m.closeCh). Conn.Close nils out c.webrtcMgr after calling it, so it
// runs at most once today; confirm no other caller can exist.
func (m *webrtcManager) close() error {
	// Close signaling client first to stop new messages
	if m.signalingClient != nil {
		if err := m.signalingClient.Close(); err != nil {
			m.logf("webrtc: signaling client close error: %v", err)
		}
	}
	// Signal runLoop to stop
	close(m.closeCh)
	// Wait for runLoop to finish with timeout
	select {
	case <-m.runLoopStoppedCh:
	case <-time.After(2 * time.Second):
		m.logf("webrtc: close timed out, forcing shutdown")
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	// Close all peer connections. Close errors are deliberately ignored
	// here since we are tearing everything down anyway.
	for _, ps := range m.peerConnectionsByEndpoint {
		if ps.peerConn != nil {
			ps.peerConn.Close()
		}
	}
	m.peerConnectionsByEndpoint = nil
	m.peerConnectionsByDisco = nil
	return nil
}
// startConnection asks the manager's runLoop to initiate a WebRTC
// connection to ep. The request is dropped silently while the manager is
// shutting down, and dropped with a log line when the queue is full.
func (m *webrtcManager) startConnection(ep *endpoint) {
	select {
	case <-m.closeCh:
		// Shutting down; nothing to do.
	case m.startConnectionCh <- ep:
	default:
		m.logf("webrtc: startConnection queue full for %v", ep.nodeAddr)
	}
}
// deliverWebRTCMsg delivers one DataChannel message to the receive pipeline.
// It handles both single packets and batches (webrtcBatchMagic framing) so
// the logic is shared between the native detached-reader path and the
// JS/fallback OnMessage callback path.
func (m *webrtcManager) deliverWebRTCMsg(ps *webrtcPeerState, data []byte) {
	if len(data) == 0 {
		return
	}
	// Batch: [0xBA magic][2-byte BE len][pkt]...[2-byte BE len][pkt]
	if data[0] == webrtcBatchMagic {
		data = data[1:]
		for len(data) >= 2 {
			pktLen := int(binary.BigEndian.Uint16(data))
			data = data[2:]
			if pktLen > len(data) {
				m.logf("webrtc: batch framing error for peer %v: pktLen %d > remaining %d",
					ps.remoteDisco.ShortString(), pktLen, len(data))
				return
			}
			m.conn.receiveWebRTC(data[:pktLen], ps.remoteNodeKey)
			data = data[pktLen:]
		}
		// A well-formed batch consumes the message exactly. A leftover
		// trailing byte means the sender's framing is corrupt; report it
		// instead of silently ignoring it.
		if len(data) != 0 {
			m.logf("webrtc: batch framing error for peer %v: %d trailing bytes",
				ps.remoteDisco.ShortString(), len(data))
		}
		return
	}
	m.conn.receiveWebRTC(data, ps.remoteNodeKey)
}
// runDataChannelReader is the per-peer receive loop used when DetachDataChannels
// is enabled (native builds). It reads directly from the detached
// io.ReadWriteCloser into a reused buffer, avoiding the per-message goroutine
// wakeup and allocation that the OnMessage callback path incurs.
func (m *webrtcManager) runDataChannelReader(ps *webrtcPeerState, rwc io.ReadWriteCloser) {
	// Size the buffer to hold the largest possible batch.
	// 64 WireGuard packets × ~1420 bytes + framing < 100 KiB; 256 KiB is safe.
	buf := make([]byte, 256*1024)
	for {
		n, err := rwc.Read(buf)
		// Per the io.Reader contract, process any returned bytes before
		// considering the error: a reader may return n > 0 alongside a
		// non-nil err, and that final message must not be dropped.
		if n > 0 {
			m.deliverWebRTCMsg(ps, buf[:n])
		}
		if err == nil {
			continue
		}
		// EOF/closed-pipe/closed-conn are the normal shutdown signals;
		// anything else is worth logging.
		if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrClosedPipe) && !errors.Is(err, net.ErrClosed) {
			m.logf("webrtc: data channel read error for peer %v: %v", ps.remoteDisco.ShortString(), err)
		}
		// Clear the fast-path pointer so senders stop using this channel.
		ps.dcRW.Store(nil)
		return
	}
}
// getRemoteAddr returns the actual remote address for a WebRTC peer
// connection, or the zero AddrPort if the peer is not currently connected.
func (m *webrtcManager) getRemoteAddr(disco key.DiscoPublic) netip.AddrPort {
	m.mu.RLock()
	defer m.mu.RUnlock()
	ps, ok := m.peerConnectionsByDisco[disco]
	if !ok || ps.state != webrtcStateConnected {
		return netip.AddrPort{}
	}
	return ps.remoteAddr
}
// runLoop is the main event loop for the WebRTC manager. It serializes
// connection setup and readiness handling until closeCh is closed, and
// closes runLoopStoppedCh on exit so close() can wait for it.
func (m *webrtcManager) runLoop() {
	defer close(m.runLoopStoppedCh)
	for {
		select {
		case <-m.closeCh:
			return
		case ep := <-m.startConnectionCh:
			m.handleStartConnection(ep)
		case ev := <-m.connectionReadyCh:
			m.handleConnectionReady(ev)
		}
	}
}
// handleStartConnection creates a new WebRTC connection to an endpoint and
// sends an SDP offer to the peer via the signaling client. It runs on the
// manager's runLoop goroutine.
func (m *webrtcManager) handleStartConnection(ep *endpoint) {
	m.mu.Lock()
	// Check if we already have a connection in flight or established.
	if ps, exists := m.peerConnectionsByEndpoint[ep]; exists {
		if ps.state == webrtcStateConnecting || ps.state == webrtcStateConnected {
			m.mu.Unlock()
			return
		}
	}
	// Get disco keys
	localDisco := m.conn.DiscoPublicKey()
	disco := ep.disco.Load()
	if disco == nil {
		m.mu.Unlock()
		m.logf("webrtc: cannot start connection, peer has no disco key")
		return
	}
	remoteDisco := disco.key
	m.logf("webrtc: starting connection to peer %v (disco %v)", ep.nodeAddr, remoteDisco.ShortString())
	m.mu.Unlock()
	// Create peer connection
	config := webrtc.Configuration{
		ICEServers: []webrtc.ICEServer{
			{
				URLs: []string{"stun:stun.l.google.com:19302"},
			},
		},
		ICETransportPolicy: webrtc.ICETransportPolicyAll,
	}
	peerConn, err := m.api.NewPeerConnection(config)
	if err != nil {
		m.logf("webrtc: failed to create peer connection: %v", err)
		return
	}
	ps := &webrtcPeerState{
		ep:            ep,
		peerConn:      peerConn,
		localDisco:    localDisco,
		remoteDisco:   remoteDisco,
		remoteNodeKey: ep.publicKey,
		state:         webrtcStateConnecting,
		createdAt:     time.Now(),
	}
	// Store peer state.
	// NOTE(review): between the existence check above and this store,
	// handleRemoteOffer (signaling goroutine) may also create state for the
	// same peer; the later store wins. Confirm whether glare needs explicit
	// tie-breaking.
	m.mu.Lock()
	m.peerConnectionsByEndpoint[ep] = ps
	m.peerConnectionsByDisco[remoteDisco] = ps
	m.mu.Unlock()
	// Set up connection state handler
	peerConn.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
		m.handleConnectionStateChange(ps, state)
	})
	// Set up ICE candidate handler
	peerConn.OnICECandidate(func(candidate *webrtc.ICECandidate) {
		if candidate != nil {
			m.handleLocalICECandidate(ps, candidate)
		}
	})
	// Create an unordered, unreliable data channel (MaxRetransmits=0).
	// WireGuard is designed to run over raw UDP, which is unordered and
	// unreliable. Using an ordered/reliable DataChannel (the default) wraps
	// WireGuard in SCTP's reliable-ordered-stream semantics, causing
	// head-of-line blocking whenever a packet is lost: SCTP holds back all
	// subsequent packets until the missing one is retransmitted and delivered
	// in order. That is why throughput over WebRTC was worse than DERP.
	// Setting Ordered=false and MaxRetransmits=0 makes the DataChannel behave
	// like a UDP socket, which is exactly what WireGuard expects.
	//
	// (This variable was previously named "unordered", which contradicted
	// its value; it is the DataChannelInit.Ordered flag, set to false.)
	ordered := false
	maxRetransmits := uint16(0)
	dataChannel, err := peerConn.CreateDataChannel("tailscale-wg", &webrtc.DataChannelInit{
		Ordered:        &ordered,
		MaxRetransmits: &maxRetransmits,
	})
	if err != nil {
		m.logf("webrtc: failed to create data channel: %v", err)
		peerConn.Close()
		return
	}
	ps.dataChannel = dataChannel
	// Set up data channel handlers.
	// With DetachDataChannels enabled, OnMessage cannot be used. Instead we
	// call Detach() inside OnOpen to get a raw io.ReadWriteCloser and spin
	// up a dedicated reader goroutine, which eliminates per-packet callback
	// overhead and goroutine wakeups.
	setOnError(dataChannel, func(err error) {
		m.logf("webrtc: data channel error for peer %v: %v", remoteDisco.ShortString(), err)
	})
	dataChannel.OnOpen(func() {
		// Native: DetachDataChannels was enabled; get a raw io.ReadWriteCloser
		// and spin a dedicated reader goroutine (zero per-message allocations).
		// JS/fallback: Detach() returns an error; fall back to OnMessage
		// callbacks, which is the only API available in the browser.
		if rwc, err := dataChannel.Detach(); err == nil {
			ps.dcRW.Store(&dataChannelRW{rwc})
			go m.runDataChannelReader(ps, rwc)
		} else {
			dataChannel.OnMessage(func(msg webrtc.DataChannelMessage) {
				m.deliverWebRTCMsg(ps, msg.Data)
			})
		}
		m.logf("webrtc: data channel opened for peer %v", remoteDisco.ShortString())
		// Don't block this pion callback forever if runLoop has already
		// exited: a bare channel send here could leak the callback
		// goroutine after close().
		select {
		case m.connectionReadyCh <- webrtcConnectionReadyEvent{remoteDisco: remoteDisco, ep: ep}:
		case <-m.closeCh:
		}
	})
	// Create and send offer
	offer, err := peerConn.CreateOffer(nil)
	if err != nil {
		m.logf("webrtc: failed to create offer: %v", err)
		peerConn.Close()
		return
	}
	if err := peerConn.SetLocalDescription(offer); err != nil {
		m.logf("webrtc: failed to set local description: %v", err)
		peerConn.Close()
		return
	}
	// Send offer via signaling
	if err := m.signalingClient.Offer(localDisco.String(), remoteDisco.String(), &offer); err != nil {
		m.logf("webrtc: failed to send offer: %v", err)
		peerConn.Close()
		return
	}
	m.logf("webrtc: sent offer to peer %v", remoteDisco.ShortString())
}
// HandleOffer implements rtclib.SignalHandler.
func (m *webrtcManager) HandleOffer(from, to string, offer *webrtc.SessionDescription) {
	m.logf("webrtc: received offer from=%s", from)
	var sender key.DiscoPublic
	err := sender.UnmarshalText([]byte(from))
	if err != nil {
		m.logf("webrtc: invalid sender disco key: %v", err)
		return
	}
	m.handleRemoteOffer(sender, offer)
}
// HandleAnswer implements rtclib.SignalHandler.
func (m *webrtcManager) HandleAnswer(from, to string, answer *webrtc.SessionDescription) {
	m.logf("webrtc: received answer from=%s", from)
	var sender key.DiscoPublic
	err := sender.UnmarshalText([]byte(from))
	if err != nil {
		m.logf("webrtc: invalid sender disco key: %v", err)
		return
	}
	m.handleRemoteAnswer(sender, answer)
}
// HandleCandidate implements rtclib.SignalHandler.
func (m *webrtcManager) HandleCandidate(from, to string, candidate *webrtc.ICECandidateInit) {
	m.logf("webrtc: received candidate from=%s", from)
	var sender key.DiscoPublic
	err := sender.UnmarshalText([]byte(from))
	if err != nil {
		m.logf("webrtc: invalid sender disco key: %v", err)
		return
	}
	m.handleRemoteCandidate(sender, candidate)
}
// handleRemoteOffer processes an incoming SDP offer from a peer. If no local
// connection state exists yet (the remote side initiated first), it creates
// an answerer PeerConnection for the peer's endpoint, wires up the handlers,
// and then answers the offer via the signaling client.
func (m *webrtcManager) handleRemoteOffer(remoteDisco key.DiscoPublic, offer *webrtc.SessionDescription) {
	// For incoming connections, we need to find the endpoint by disco key
	m.mu.Lock()
	ps, exists := m.peerConnectionsByDisco[remoteDisco]
	m.mu.Unlock()
	if !exists {
		// We received an offer but don't have a connection yet.
		// This happens when the remote peer initiated first (glare scenario).
		// Find the endpoint by disco key and create peer connection state.
		ep := m.conn.findEndpointByDisco(remoteDisco)
		if ep == nil {
			m.logf("webrtc: received offer from unknown peer %v with no endpoint", remoteDisco.ShortString())
			return
		}
		m.logf("webrtc: received offer from peer %v, creating answerer connection", remoteDisco.ShortString())
		// Create peer connection for incoming offer
		config := webrtc.Configuration{
			ICEServers: []webrtc.ICEServer{
				{
					URLs: []string{"stun:stun.l.google.com:19302"},
				},
			},
			ICETransportPolicy: webrtc.ICETransportPolicyAll,
		}
		peerConn, err := m.api.NewPeerConnection(config)
		if err != nil {
			m.logf("webrtc: failed to create peer connection for incoming offer: %v", err)
			return
		}
		localDisco := m.conn.DiscoPublicKey()
		ps = &webrtcPeerState{
			ep:            ep,
			peerConn:      peerConn,
			localDisco:    localDisco,
			remoteDisco:   remoteDisco,
			remoteNodeKey: ep.publicKey,
			state:         webrtcStateConnecting,
			createdAt:     time.Now(),
		}
		// Store peer state
		m.mu.Lock()
		m.peerConnectionsByEndpoint[ep] = ps
		m.peerConnectionsByDisco[remoteDisco] = ps
		m.mu.Unlock()
		// Set up connection state handler
		peerConn.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
			m.handleConnectionStateChange(ps, state)
		})
		// Set up ICE candidate handler
		peerConn.OnICECandidate(func(candidate *webrtc.ICECandidate) {
			if candidate != nil {
				m.handleLocalICECandidate(ps, candidate)
			}
		})
		// Set up data channel handler (for answerer, we wait for the data channel from offerer).
		peerConn.OnDataChannel(func(dc *webrtc.DataChannel) {
			m.logf("webrtc: received data channel from peer %v", remoteDisco.ShortString())
			ps.dataChannel = dc
			setOnError(dc, func(err error) {
				m.logf("webrtc: data channel error for peer %v: %v", remoteDisco.ShortString(), err)
			})
			dc.OnOpen(func() {
				// Native: Detach succeeds and a dedicated reader goroutine runs.
				// JS/fallback: Detach fails; use OnMessage callbacks instead.
				if rwc, err := dc.Detach(); err == nil {
					ps.dcRW.Store(&dataChannelRW{rwc})
					go m.runDataChannelReader(ps, rwc)
				} else {
					dc.OnMessage(func(msg webrtc.DataChannelMessage) {
						m.deliverWebRTCMsg(ps, msg.Data)
					})
				}
				m.logf("webrtc: data channel opened for peer %v", remoteDisco.ShortString())
				// Don't block this pion callback forever if runLoop has
				// already exited: a bare channel send here could leak the
				// callback goroutine after close().
				select {
				case m.connectionReadyCh <- webrtcConnectionReadyEvent{remoteDisco: remoteDisco, ep: ep}:
				case <-m.closeCh:
				}
			})
		})
	}
	// NOTE(review): in a true glare (we already sent our own offer on this
	// PeerConnection), SetRemoteDescription on an offerer connection fails
	// and is only logged; confirm whether a tie-break/rollback is needed.
	if err := ps.peerConn.SetRemoteDescription(*offer); err != nil {
		m.logf("webrtc: failed to set remote description: %v", err)
		return
	}
	// Create answer
	answer, err := ps.peerConn.CreateAnswer(nil)
	if err != nil {
		m.logf("webrtc: failed to create answer: %v", err)
		return
	}
	if err := ps.peerConn.SetLocalDescription(answer); err != nil {
		m.logf("webrtc: failed to set local description: %v", err)
		return
	}
	// Send answer via signaling
	if err := m.signalingClient.Answer(ps.localDisco.String(), remoteDisco.String(), &answer); err != nil {
		m.logf("webrtc: failed to send answer: %v", err)
		return
	}
	m.logf("webrtc: sent answer to peer %v", remoteDisco.ShortString())
}
// handleRemoteAnswer processes an incoming SDP answer from a peer:
// it looks up the peer state by disco key and applies the answer as the
// remote description on that peer connection.
func (m *webrtcManager) handleRemoteAnswer(remoteDisco key.DiscoPublic, answer *webrtc.SessionDescription) {
	m.mu.Lock()
	ps, ok := m.peerConnectionsByDisco[remoteDisco]
	m.mu.Unlock()
	if !ok {
		m.logf("webrtc: received answer from unknown peer %v", remoteDisco.ShortString())
		return
	}
	if err := ps.peerConn.SetRemoteDescription(*answer); err != nil {
		m.logf("webrtc: failed to set remote description: %v", err)
		return
	}
	m.logf("webrtc: set remote description for peer %v", remoteDisco.ShortString())
}
// handleRemoteCandidate processes an incoming ICE candidate from a peer.
// As a side effect it records the address advertised in the candidate
// string (best-effort) before handing the candidate to pion.
func (m *webrtcManager) handleRemoteCandidate(remoteDisco key.DiscoPublic, candidate *webrtc.ICECandidateInit) {
	m.mu.Lock()
	ps, ok := m.peerConnectionsByDisco[remoteDisco]
	m.mu.Unlock()
	if !ok {
		m.logf("webrtc: received candidate from unknown peer %v", remoteDisco.ShortString())
		return
	}
	// Best-effort: remember the peer address embedded in the candidate SDP,
	// e.g. "candidate:... udp ... <ip> <port> typ ...".
	if candidate.Candidate != "" {
		addr := parseICECandidateAddr(candidate.Candidate)
		if addr.IsValid() {
			m.mu.Lock()
			ps.remoteAddr = addr
			m.mu.Unlock()
			m.logf("webrtc: peer %v candidate address: %v", remoteDisco.ShortString(), addr)
		}
	}
	if err := ps.peerConn.AddICECandidate(*candidate); err != nil {
		m.logf("webrtc: failed to add ICE candidate: %v", err)
		return
	}
	m.logf("webrtc: added ICE candidate for peer %v", remoteDisco.ShortString())
}
// parseICECandidateAddr extracts the IP:port from an ICE candidate SDP string.
// Example candidate: "candidate:1234 1 udp 2130706431 192.168.1.100 54321 typ host"
func parseICECandidateAddr(candidate string) netip.AddrPort {
fields := strings.Fields(candidate)
// Format: candidate:<foundation> <component> <protocol> <priority> <ip> <port> typ <type>
if len(fields) < 7 {
return netip.AddrPort{}
}
ip := fields[4]
port := fields[5]
addr, err := netip.ParseAddr(ip)
if err != nil {
return netip.AddrPort{}
}
var portNum uint16
if _, err := fmt.Sscanf(port, "%d", &portNum); err != nil {
return netip.AddrPort{}
}
return netip.AddrPortFrom(addr, portNum)
}
// handleLocalICECandidate forwards a locally gathered ICE candidate to the
// remote peer through the signaling client.
func (m *webrtcManager) handleLocalICECandidate(ps *webrtcPeerState, candidate *webrtc.ICECandidate) {
	ci := candidate.ToJSON()
	err := m.signalingClient.Candidate(ps.localDisco.String(), ps.remoteDisco.String(), &ci)
	if err != nil {
		m.logf("webrtc: failed to send candidate: %v", err)
		return
	}
	m.logf("webrtc: sent ICE candidate to peer %v", ps.remoteDisco.ShortString())
}
// handleConnectionStateChange handles WebRTC connection state changes.
// It records the new state on ps under m.mu; on Connected it additionally
// logs the selected ICE candidate pair from a separate goroutine.
func (m *webrtcManager) handleConnectionStateChange(ps *webrtcPeerState, state webrtc.PeerConnectionState) {
	m.logf("webrtc: connection state changed to %s for peer %v", state.String(), ps.remoteDisco.ShortString())
	m.mu.Lock()
	defer m.mu.Unlock()
	switch state {
	case webrtc.PeerConnectionStateConnected:
		ps.state = webrtcStateConnected
		// Log the selected ICE candidate pair so we can confirm the actual
		// data path (LAN host candidate vs. STUN server-reflexive vs. relay).
		go func() {
			// NOTE(review): assumes SCTP(), Transport(), and ICETransport()
			// are all non-nil once the connection reports Connected — confirm.
			cp, err := ps.peerConn.SCTP().Transport().ICETransport().GetSelectedCandidatePair()
			if err != nil || cp == nil {
				m.logf("webrtc: peer %v connected (selected candidate pair unavailable: %v)",
					ps.remoteDisco.ShortString(), err)
				return
			}
			m.logf("webrtc: peer %v connected via %s:%d → %s:%d (local %s, remote %s)",
				ps.remoteDisco.ShortString(),
				cp.Local.Address, cp.Local.Port,
				cp.Remote.Address, cp.Remote.Port,
				cp.Local.Typ, cp.Remote.Typ)
		}()
	case webrtc.PeerConnectionStateFailed:
		ps.state = webrtcStateFailed
		ps.lastError = errors.New("connection failed")
		// Drop the detached writer so sendPacket stops using a dead channel.
		ps.dcRW.Store(nil)
	case webrtc.PeerConnectionStateClosed:
		ps.state = webrtcStateClosed
		ps.dcRW.Store(nil)
	case webrtc.PeerConnectionStateDisconnected:
		// Transient state, keep current state
	}
}
// handleConnectionReady marks a WebRTC connection as ready and, when the
// path wins against the endpoint's current best address, installs it as
// bestAddr so traffic switches onto the data channel.
func (m *webrtcManager) handleConnectionReady(event webrtcConnectionReadyEvent) {
	m.logf("webrtc: connection ready for peer %v", event.remoteDisco.ShortString())
	ep := event.ep
	ep.mu.Lock()
	defer ep.mu.Unlock()
	// The magic IP marks this path as WebRTC rather than UDP; the port is a
	// fixed dummy value (the real transport is the data channel), mirroring
	// how DERP paths are represented.
	candidate := addrQuality{
		epAddr: epAddr{
			ap: netip.AddrPortFrom(tailcfg.WebRTCMagicIPAddr, 12345),
		},
		latency: 0, // measured later by disco pings, as with DERP
	}
	if betterAddr(candidate, ep.bestAddr) {
		now := mono.Now()
		ep.bestAddr = candidate
		ep.bestAddrAt = now
		ep.trustBestAddrUntil = now.Add(5 * time.Minute)
		m.logf("webrtc: updated endpoint %v with WebRTC path", ep.nodeAddr)
	}
}
// sendPacket sends a packet over a WebRTC data channel.
// The hot path is cheap: a read-lock (never a write-lock) to look up the
// peer state, then an atomic load for the detached channel, so concurrent
// senders for different peers do not contend on a write lock.
func (m *webrtcManager) sendPacket(disco key.DiscoPublic, b []byte) error {
	m.mu.RLock()
	ps, ok := m.peerConnectionsByDisco[disco]
	m.mu.RUnlock()
	if !ok {
		return errors.New("no WebRTC connection")
	}
	// Native path: DetachDataChannels was enabled; use the raw io.ReadWriteCloser.
	if rw := ps.dcRW.Load(); rw != nil {
		if _, err := rw.Write(b); err != nil {
			return fmt.Errorf("send failed: %w", err)
		}
		return nil
	}
	// JS/fallback path: use DataChannel.Send() directly.
	// NOTE(review): ps.dataChannel is assigned in the OnDataChannel callback
	// without holding m.mu — confirm this unguarded read cannot race.
	dc := ps.dataChannel
	if dc == nil || dc.ReadyState() != webrtc.DataChannelStateOpen {
		return errors.New("data channel not ready")
	}
	return dc.Send(b)
}
// receiveWebRTC reads packets from the WebRTC receive channel.
// It is called by wireguard-go through the conn.Bind interface.
// It blocks until at least one packet is available, then drains as many
// additional packets as are immediately ready (up to len(buffs)),
// returning the batch size. Returns net.ErrClosed once the channel is
// closed (or the bind is closed) and no packets remain to deliver.
func (c *connBind) receiveWebRTC(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
	// Block until the first packet arrives (or the channel is closed).
	wr, ok := <-c.webrtcRecvCh
	if !ok || c.isClosed() {
		return 0, net.ErrClosed
	}
	num := 0
	n, ep := c.processWebRTCReadResult(wr, buffs[num])
	if n > 0 {
		sizes[num] = n
		eps[num] = ep
		num++
	}
	// Drain any additional packets that are immediately available.
	for num < len(buffs) {
		select {
		case wr, ok = <-c.webrtcRecvCh:
			if !ok || c.isClosed() {
				// Deliver what we already collected before reporting closure.
				if num > 0 {
					return num, nil
				}
				return 0, net.ErrClosed
			}
			n, ep = c.processWebRTCReadResult(wr, buffs[num])
			if n > 0 {
				sizes[num] = n
				eps[num] = ep
				num++
			}
		default:
			// Nothing more queued; hand the batch to wireguard-go.
			return num, nil
		}
	}
	return num, nil
}

View File

@ -0,0 +1,37 @@
//go:build js
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package magicsock
import (
"github.com/pion/webrtc/v4"
"tailscale.com/types/key"
)
// newWebRTCManagerBase constructs the webrtcManager for js/wasm builds.
// The js build delegates to the browser's WebRTC stack, so no SettingEngine
// tuning is applied here.
func newWebRTCManagerBase(c *Conn, signalingURL string) *webrtcManager {
	api := webrtc.NewAPI(webrtc.WithSettingEngine(webrtc.SettingEngine{}))
	return &webrtcManager{
		logf:                      c.logf,
		conn:                      c,
		peerConnectionsByEndpoint: make(map[*endpoint]*webrtcPeerState),
		peerConnectionsByDisco:    make(map[key.DiscoPublic]*webrtcPeerState),
		startConnectionCh:         make(chan *endpoint, 256),
		connectionReadyCh:         make(chan webrtcConnectionReadyEvent, 16),
		closeCh:                   make(chan struct{}),
		runLoopStoppedCh:          make(chan struct{}),
		api:                       api,
	}
}
// setOnError is a no-op on js builds: the wasm *webrtc.DataChannel bindings
// do not expose an OnError callback. The !js build registers fn via OnError.
func setOnError(dc *webrtc.DataChannel, fn func(error)) {
	// NO-OP... *webrtc.DataChannel does not have OnError for js.
}

View File

@ -0,0 +1,70 @@
//go:build !js
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package magicsock
import (
"time"
"github.com/pion/webrtc/v4"
"tailscale.com/types/key"
)
// newWebRTCManagerBase constructs the webrtcManager for non-js builds,
// tuning pion's SettingEngine for high-throughput data-channel transport.
func newWebRTCManagerBase(c *Conn, signalingURL string) *webrtcManager {
	se := webrtc.SettingEngine{}
	// 16 MiB SCTP receive buffer: SCTP's flow-control window is bounded by
	// this value, and the pion default (~32 KiB) caps throughput.
	se.SetSCTPMaxReceiveBufferSize(16 * 1024 * 1024)
	// Wider DTLS replay-protection window: the default (64) drops legitimate
	// packets as duplicates once the sender runs more than 64 packets ahead
	// of the receiver, which happens easily at Gbps speeds.
	se.SetDTLSReplayProtectionWindow(8192)
	// Lower the SCTP RTO ceiling so congestion control recovers from loss in
	// ~100ms instead of stalling for a full second on a low-latency P2P link.
	se.SetSCTPRTOMax(100 * time.Millisecond)
	// Detached data channels expose a raw io.ReadWriteCloser so receives can
	// Read() into pre-allocated buffers in a tight loop, instead of paying a
	// per-packet OnMessage allocation plus goroutine wakeup — matching how
	// the UDP receive path works.
	se.DetachDataChannels()
	// DTLS already provides integrity and authenticity for all data, so
	// SCTP's per-chunk CRC32c is redundant CPU work; zero-checksum mode
	// (RFC 9260) removes it.
	se.EnableSCTPZeroChecksum(true)
	// A MediaEngine is required by the API even though only data channels
	// are used.
	api := webrtc.NewAPI(
		webrtc.WithSettingEngine(se),
		webrtc.WithMediaEngine(&webrtc.MediaEngine{}),
	)
	return &webrtcManager{
		logf:                      c.logf,
		conn:                      c,
		peerConnectionsByEndpoint: make(map[*endpoint]*webrtcPeerState),
		peerConnectionsByDisco:    make(map[key.DiscoPublic]*webrtcPeerState),
		startConnectionCh:         make(chan *endpoint, 256),
		connectionReadyCh:         make(chan webrtcConnectionReadyEvent, 16),
		closeCh:                   make(chan struct{}),
		runLoopStoppedCh:          make(chan struct{}),
		api:                       api,
	}
}
// setOnError registers fn as the data channel's error callback.
// (The js build provides a no-op counterpart because the wasm bindings do
// not expose OnError.)
func setOnError(dc *webrtc.DataChannel, fn func(error)) {
	dc.OnError(fn)
}

View File

@ -0,0 +1,435 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package magicsock
import (
"net/http"
"net/http/httptest"
"net/netip"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/pion/webrtc/v4"
"tailscale.com/rtclib"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// testSignalHandler is a test implementation of rtclib.SignalHandler that
// counts each kind of signaling callback it receives and logs them via t.
type testSignalHandler struct {
	offerCount int // guarded by mu
	answerCount int // guarded by mu
	candidateCount int // guarded by mu
	t *testing.T
	mu sync.Mutex
}
// HandleOffer records and logs one received offer.
func (h *testSignalHandler) HandleOffer(from, to string, offer *webrtc.SessionDescription) {
	h.mu.Lock()
	h.offerCount++
	h.mu.Unlock()
	h.t.Logf("Received offer from %s to %s", from, to)
}
// HandleAnswer records and logs one received answer.
func (h *testSignalHandler) HandleAnswer(from, to string, answer *webrtc.SessionDescription) {
	h.mu.Lock()
	h.answerCount++
	h.mu.Unlock()
	h.t.Logf("Received answer from %s to %s", from, to)
}
// HandleCandidate records and logs one received ICE candidate.
func (h *testSignalHandler) HandleCandidate(from, to string, candidate *webrtc.ICECandidateInit) {
	h.mu.Lock()
	h.candidateCount++
	h.mu.Unlock()
	h.t.Logf("Received candidate from %s to %s", from, to)
}
// TestWebRTCIntegration_MockSignalingServer tests WebRTC with a mock signaling server:
// it starts one client, sends a single offer, and checks the server recorded it.
// NOTE(review): relies on fixed sleeps for synchronization — may flake under load.
func TestWebRTCIntegration_MockSignalingServer(t *testing.T) {
	// Create mock signaling server
	server := newMockSignalingServer(t)
	defer server.Close()
	t.Logf("Mock signaling server running at %s", server.URL)
	// Verify server accepts connections
	client := newSignalingClient(server.URL, t.Logf)
	handler := &testSignalHandler{t: t}
	err := client.Start(handler)
	if err != nil {
		t.Fatalf("Failed to start client: %v", err)
	}
	defer client.Close()
	// Wait for connection
	time.Sleep(100 * time.Millisecond)
	// Send a test message
	disco1 := key.NewDisco().Public()
	disco2 := key.NewDisco().Public()
	err = client.Offer(disco1.String(), disco2.String(), &webrtc.SessionDescription{
		Type: webrtc.SDPTypeOffer,
		SDP: "test",
	})
	if err != nil {
		t.Fatalf("Failed to send offer: %v", err)
	}
	// Give time for message to be processed
	time.Sleep(100 * time.Millisecond)
	// Verify the server received the message
	serverMsgs := server.GetReceivedMessages()
	if len(serverMsgs) == 0 {
		t.Error("Server did not receive any messages")
	} else {
		t.Logf("Server received %d messages", len(serverMsgs))
		for i, msg := range serverMsgs {
			t.Logf("Message %d: type=%s from=%s to=%s", i, msg.Type, msg.From, msg.To)
		}
	}
}
// TestWebRTCIntegration_MessageRelay tests message relay through the signaling
// server: two clients connect, one sends an offer, the other answers, and the
// counts are logged. The relay checks are advisory (logged, not asserted)
// because the mock server broadcasts to all other clients rather than routing
// by the "to" field.
func TestWebRTCIntegration_MessageRelay(t *testing.T) {
	server := newMockSignalingServer(t)
	defer server.Close()
	// Create two clients
	handler1 := &testSignalHandler{t: t}
	handler2 := &testSignalHandler{t: t}
	client1 := newSignalingClient(server.URL, func(format string, args ...any) {
		t.Logf("[Client1] "+format, args...)
	})
	client2 := newSignalingClient(server.URL, func(format string, args ...any) {
		t.Logf("[Client2] "+format, args...)
	})
	err := client1.Start(handler1)
	if err != nil {
		t.Fatalf("Failed to start client1: %v", err)
	}
	defer client1.Close()
	err = client2.Start(handler2)
	if err != nil {
		t.Fatalf("Failed to start client2: %v", err)
	}
	defer client2.Close()
	// Wait for both to connect
	time.Sleep(200 * time.Millisecond)
	// Client 1 sends offer to Client 2
	disco1 := key.NewDisco().Public()
	disco2 := key.NewDisco().Public()
	if err := client1.Offer(disco1.String(), disco2.String(), &webrtc.SessionDescription{
		Type: webrtc.SDPTypeOffer,
		SDP: "v=0...",
	}); err != nil {
		t.Fatalf("Client1 failed to send offer: %v", err)
	}
	// Wait for message relay
	time.Sleep(200 * time.Millisecond)
	// Verify client2 received the offer
	handler2.mu.Lock()
	c2offers := handler2.offerCount
	handler2.mu.Unlock()
	if c2offers > 0 {
		t.Logf("Client2 received %d offers (relay working)", c2offers)
	} else {
		t.Log("Client2 did not receive offers (relay may need proper routing)")
	}
	// Client 2 sends answer back to Client 1
	if err := client2.Answer(disco2.String(), disco1.String(), &webrtc.SessionDescription{
		Type: webrtc.SDPTypeAnswer,
		SDP: "v=0...",
	}); err != nil {
		t.Fatalf("Client2 failed to send answer: %v", err)
	}
	// Wait for message relay
	time.Sleep(200 * time.Millisecond)
	// Log final message counts
	handler1.mu.Lock()
	c1answers := handler1.answerCount
	handler1.mu.Unlock()
	handler2.mu.Lock()
	c2FinalOffers := handler2.offerCount
	handler2.mu.Unlock()
	t.Logf("Final message counts: Client1 answers=%d, Client2 offers=%d", c1answers, c2FinalOffers)
	t.Log("Integration test completed successfully")
}
// TestWebRTCIntegration_SignalingFlow tests the complete signaling flow
// (offer → answer → candidate) against the mock server, asserting only that
// each send succeeds.
func TestWebRTCIntegration_SignalingFlow(t *testing.T) {
	server := newMockSignalingServer(t)
	defer server.Close()
	client := newSignalingClient(server.URL, t.Logf)
	handler := &testSignalHandler{t: t}
	err := client.Start(handler)
	if err != nil {
		t.Fatalf("Failed to start client: %v", err)
	}
	defer client.Close()
	time.Sleep(100 * time.Millisecond)
	disco1 := key.NewDisco().Public()
	disco2 := key.NewDisco().Public()
	// Simulate complete signaling flow
	steps := []struct {
		name string
		fn func() error
	}{
		{
			name: "send_offer",
			fn: func() error {
				return client.Offer(disco1.String(), disco2.String(), &webrtc.SessionDescription{
					Type: webrtc.SDPTypeOffer,
					SDP: "v=0 offer",
				})
			},
		},
		{
			name: "send_answer",
			fn: func() error {
				return client.Answer(disco2.String(), disco1.String(), &webrtc.SessionDescription{
					Type: webrtc.SDPTypeAnswer,
					SDP: "v=0 answer",
				})
			},
		},
		{
			name: "send_candidate",
			fn: func() error {
				return client.Candidate(disco1.String(), disco2.String(), &webrtc.ICECandidateInit{
					Candidate: "test",
				})
			},
		},
	}
	for _, step := range steps {
		t.Logf("Step: %s", step.name)
		if err := step.fn(); err != nil {
			t.Errorf("Failed to execute %s: %v", step.name, err)
		}
		time.Sleep(50 * time.Millisecond)
	}
	t.Logf("Signaling flow completed with %d steps", len(steps))
}
// mockSignalingServer is a simple WebSocket server that relays signaling messages
// to every connected client other than the sender, recording each message for
// later inspection via GetReceivedMessages.
type mockSignalingServer struct {
	*httptest.Server
	upgrader websocket.Upgrader
	mu sync.Mutex // guards clients and messages
	clients map[*websocket.Conn]bool
	messages []rtclib.SignalingMessage
	t *testing.T
}
// newMockSignalingServer starts an httptest-backed WebSocket signaling server
// whose URL is rewritten to the ws:// scheme for clients to dial directly.
func newMockSignalingServer(t *testing.T) *mockSignalingServer {
	s := &mockSignalingServer{
		upgrader: websocket.Upgrader{
			CheckOrigin: func(r *http.Request) bool { return true },
		},
		clients:  make(map[*websocket.Conn]bool),
		messages: make([]rtclib.SignalingMessage, 0),
		t:        t,
	}
	s.Server = httptest.NewServer(http.HandlerFunc(s.handleWebSocket))
	// Rewrite "http://..." to "ws://...".
	s.Server.URL = "ws" + s.Server.URL[len("http"):]
	return s
}
// handleWebSocket upgrades the request to a WebSocket, registers the client,
// and relays every received signaling message to all other clients.
//
// Fix: the client count was previously read via len(s.clients) after
// releasing s.mu, a data race with concurrent connects/disconnects that
// `go test -race` flags; the count is now captured while the lock is held.
func (s *mockSignalingServer) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		s.t.Logf("Upgrade error: %v", err)
		return
	}
	s.mu.Lock()
	s.clients[conn] = true
	numClients := len(s.clients) // snapshot under s.mu
	s.mu.Unlock()
	defer func() {
		s.mu.Lock()
		delete(s.clients, conn)
		s.mu.Unlock()
		conn.Close()
	}()
	s.t.Logf("Client connected, total clients: %d", numClients)
	for {
		var msg rtclib.SignalingMessage
		if err := conn.ReadJSON(&msg); err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				s.t.Logf("Read error: %v", err)
			}
			break
		}
		s.t.Logf("Server received: type=%s from=%s to=%s", msg.Type, msg.From, msg.To)
		s.mu.Lock()
		s.messages = append(s.messages, msg)
		// Relay to all other clients (simple broadcast). Holding s.mu here
		// also serializes WriteJSON calls on each connection.
		for client := range s.clients {
			if client != conn {
				if err := client.WriteJSON(msg); err != nil {
					s.t.Logf("Relay error: %v", err)
				}
			}
		}
		s.mu.Unlock()
	}
}
// GetReceivedMessages returns a snapshot copy of every message the server has seen.
func (s *mockSignalingServer) GetReceivedMessages() []rtclib.SignalingMessage {
	s.mu.Lock()
	defer s.mu.Unlock()
	return append([]rtclib.SignalingMessage(nil), s.messages...)
}
// Close tears down all live client connections, then the HTTP server itself.
func (s *mockSignalingServer) Close() {
	s.mu.Lock()
	for c := range s.clients {
		c.Close()
	}
	s.mu.Unlock()
	s.Server.Close()
}
// BenchmarkWebRTCSignaling benchmarks signaling message throughput through
// the mock server's WebSocket relay path.
// NOTE(review): it constructs zero-value &testing.T{} just to satisfy helpers
// that only need logging; consider widening those helpers to testing.TB.
func BenchmarkWebRTCSignaling(b *testing.B) {
	server := newMockSignalingServer(&testing.T{})
	defer server.Close()
	client := newSignalingClient(server.URL, func(string, ...any) {})
	handler := &testSignalHandler{t: &testing.T{}}
	err := client.Start(handler)
	if err != nil {
		b.Fatalf("Failed to start client: %v", err)
	}
	defer client.Close()
	// Let the WebSocket session establish before timing begins.
	time.Sleep(100 * time.Millisecond)
	disco1 := key.NewDisco().Public()
	disco2 := key.NewDisco().Public()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if err := client.Candidate(disco1.String(), disco2.String(), &webrtc.ICECandidateInit{
			Candidate: "test",
		}); err != nil {
			b.Errorf("Send failed: %v", err)
		}
	}
}
// TestWebRTCPacketFlow tests packet flow simulation
// by building a webrtcReadResult by hand and checking that its copyBuf
// closure faithfully copies the payload.
func TestWebRTCPacketFlow(t *testing.T) {
	nodeKey := key.NewNode().Public()
	payload := []byte("test wireguard packet")
	result := webrtcReadResult{
		n:   len(payload),
		src: nodeKey,
		copyBuf: func(dst []byte) int {
			return copy(dst, payload)
		},
	}
	// Verify packet can be copied out intact.
	dst := make([]byte, 1024)
	n := result.copyBuf(dst)
	if n != len(payload) {
		t.Errorf("copyBuf returned %d bytes, want %d", n, len(payload))
	}
	if string(dst[:n]) != string(payload) {
		t.Errorf("Packet data mismatch: got %q, want %q", dst[:n], payload)
	}
	t.Logf("Packet flow test passed: %d bytes", n)
}
// TestWebRTCPathSelection tests path selection with WebRTC in the mix:
// a direct UDP address must win over the WebRTC magic-IP path, which in
// turn must win over DERP.
func TestWebRTCPathSelection(t *testing.T) {
	tests := []struct {
		name string
		paths []addrQuality
		wantBest string
	}{
		{
			name: "direct_beats_all",
			paths: []addrQuality{
				{epAddr: epAddr{ap: netip.MustParseAddrPort("1.2.3.4:1234")}},
				{epAddr: epAddr{ap: netip.AddrPortFrom(tailcfg.WebRTCMagicIPAddr, 12345)}},
				{epAddr: epAddr{ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, 1)}},
			},
			wantBest: "1.2.3.4:1234",
		},
		{
			name: "webrtc_beats_derp",
			paths: []addrQuality{
				{epAddr: epAddr{ap: netip.AddrPortFrom(tailcfg.WebRTCMagicIPAddr, 12345)}},
				{epAddr: epAddr{ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, 1)}},
			},
			wantBest: "127.3.3.41:12345",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Fold betterAddr over the candidate list, keeping the winner.
			best := tt.paths[0]
			for _, p := range tt.paths[1:] {
				if betterAddr(p, best) {
					best = p
				}
			}
			if best.ap.String() != tt.wantBest {
				t.Errorf("Best path = %v, want %v", best.ap, tt.wantBest)
			}
		})
	}
}

View File

@ -0,0 +1,296 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package magicsock
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/pion/webrtc/v4"
"tailscale.com/rtclib"
"tailscale.com/types/logger"
)
// Ensure signalingClient implements rtclib.Signaller interface.
var _ rtclib.Signaller = (*signalingClient)(nil)
// signalingClient manages WebSocket connection to signaling server.
// It reconnects with exponential backoff, queues outbound messages on
// sendCh, and dispatches inbound messages to handler.
type signalingClient struct {
	url string // signaling server WebSocket URL
	logf logger.Logf
	conn *websocket.Conn // guarded by connMu; nil when disconnected
	connMu sync.Mutex
	// Message handler
	handler rtclib.SignalHandler
	// Control channels
	ctx context.Context
	ctxCancel context.CancelFunc
	sendCh chan *rtclib.SignalingMessage // outbound queue, drained by writeLoop
	closedCh chan struct{} // closed when runLoop exits
	// Reconnection state
	reconnectDelay time.Duration // current backoff; doubled up to maxDelay
	maxDelay time.Duration
}
// newSignalingClient creates a new signaling client for the given server URL.
// The client does nothing until Start is called.
func newSignalingClient(url string, logf logger.Logf) *signalingClient {
	ctx, cancel := context.WithCancel(context.Background())
	sc := &signalingClient{
		url:            url,
		logf:           logf,
		ctx:            ctx,
		ctxCancel:      cancel,
		sendCh:         make(chan *rtclib.SignalingMessage, 16),
		closedCh:       make(chan struct{}),
		reconnectDelay: time.Second,
		maxDelay:       30 * time.Second,
	}
	return sc
}
// Start begins the signaling client's connection and message loops.
// It dials the server synchronously (so callers learn immediately whether
// the server is reachable) and then runs the lifecycle loop in the background.
func (sc *signalingClient) Start(handler rtclib.SignalHandler) error {
	sc.handler = handler
	err := sc.connect()
	if err != nil {
		return fmt.Errorf("initial signaling connection failed: %w", err)
	}
	go sc.runLoop()
	return nil
}
// connect establishes the WebSocket connection to the signaling server,
// unless one is already up. A successful dial resets the reconnect backoff.
func (sc *signalingClient) connect() error {
	sc.connMu.Lock()
	defer sc.connMu.Unlock()
	if sc.conn != nil {
		return nil // already connected
	}
	c, _, err := websocket.Dial(sc.ctx, sc.url, nil)
	if err != nil {
		return fmt.Errorf("websocket dial failed: %w", err)
	}
	sc.conn = c
	sc.reconnectDelay = time.Second // reset backoff on successful connection
	sc.logf("signaling: connected to %s", sc.url)
	return nil
}
// Close closes the signaling client.
// Shutdown order: cancel the context (stops runLoop and the read/write
// loops), close the WebSocket to unblock in-flight reads/writes, then wait
// — capped at 2s — for runLoop to acknowledge via closedCh.
func (sc *signalingClient) Close() error {
	// Cancel context to signal all goroutines to stop
	sc.ctxCancel()
	// Close the connection to unblock any read/write operations
	sc.connMu.Lock()
	if sc.conn != nil {
		sc.conn.Close(websocket.StatusNormalClosure, "")
	}
	sc.connMu.Unlock()
	// Wait for runLoop to finish with timeout
	select {
	case <-sc.closedCh:
	case <-time.After(2 * time.Second):
		sc.logf("signaling: close timed out, forcing shutdown")
	}
	sc.connMu.Lock()
	defer sc.connMu.Unlock()
	sc.conn = nil
	return nil
}
// send queues a message to be sent to the signaling server. It never
// blocks: a full queue is reported as an error rather than stalling the
// caller, and a closed client reports the context error.
//
// Fix: the previous single select combined the queue send, <-ctx.Done(),
// and a default case. With a default present, the ctx.Done branch was
// effectively never taken, so after Close a caller could keep enqueueing
// or get a misleading "queue full" error. Checking ctx first makes
// shutdown behavior deterministic.
func (sc *signalingClient) send(msg *rtclib.SignalingMessage) error {
	if err := sc.ctx.Err(); err != nil {
		return err
	}
	select {
	case sc.sendCh <- msg:
		return nil
	default:
		return errors.New("signaling send queue full")
	}
}
// runLoop manages connection lifecycle and message routing.
// It repeatedly: ensures a connection exists (retrying with exponential
// backoff up to maxDelay), spawns one read and one write loop, and tears
// the connection down when either reports an error — looping until the
// context is canceled. closedCh is closed on exit so Close can wait for it.
func (sc *signalingClient) runLoop() {
	defer close(sc.closedCh)
	for {
		select {
		case <-sc.ctx.Done():
			return
		default:
		}
		// Ensure we're connected
		if err := sc.ensureConnected(); err != nil {
			sc.logf("signaling: connection failed, retrying in %v: %v", sc.reconnectDelay, err)
			select {
			case <-time.After(sc.reconnectDelay):
				sc.reconnectDelay = min(sc.reconnectDelay*2, sc.maxDelay)
				continue
			case <-sc.ctx.Done():
				return
			}
		}
		// Run read/write loops
		// errCh is buffered (2) so whichever loop finishes second can still
		// deposit its error and exit without leaking.
		errCh := make(chan error, 2)
		go sc.readLoop(errCh)
		go sc.writeLoop(errCh)
		// Wait for error or context cancellation
		select {
		case err := <-errCh:
			sc.logf("signaling: connection error: %v", err)
			sc.disconnect()
		case <-sc.ctx.Done():
			sc.disconnect()
			return
		}
	}
}
// ensureConnected dials the signaling server if no connection is live.
func (sc *signalingClient) ensureConnected() error {
	sc.connMu.Lock()
	conn := sc.conn
	sc.connMu.Unlock()
	if conn != nil {
		return nil
	}
	return sc.connect()
}
// disconnect closes and clears the current connection, if any.
func (sc *signalingClient) disconnect() {
	sc.connMu.Lock()
	defer sc.connMu.Unlock()
	if sc.conn == nil {
		return
	}
	sc.conn.Close(websocket.StatusNormalClosure, "")
	sc.conn = nil
	sc.logf("signaling: disconnected")
}
// readLoop reads messages from WebSocket.
// Each decoded message is dispatched to the handler by type; the loop exits
// (reporting on errCh) when the connection is gone or a read fails.
func (sc *signalingClient) readLoop(errCh chan<- error) {
	for {
		sc.connMu.Lock()
		conn := sc.conn
		sc.connMu.Unlock()
		if conn == nil {
			errCh <- errors.New("no connection")
			return
		}
		var msg rtclib.SignalingMessage
		if err := wsjson.Read(sc.ctx, conn, &msg); err != nil {
			errCh <- fmt.Errorf("read failed: %w", err)
			return
		}
		if sc.handler != nil {
			switch msg.Type {
			case rtclib.MessageTypeOffer:
				sc.handler.HandleOffer(msg.From, msg.To, msg.Offer)
			case rtclib.MessageTypeAnswer:
				sc.handler.HandleAnswer(msg.From, msg.To, msg.Answer)
			case rtclib.MessageTypeCandidate:
				sc.handler.HandleCandidate(msg.From, msg.To, msg.Candidate)
			default:
				sc.logf("signaling: unknown message type: %s", msg.Type)
			}
		}
	}
}
// writeLoop writes messages to WebSocket.
// It drains sendCh, and every 30s sends a WebSocket ping to keep the
// connection alive; any failure is reported on errCh and ends the loop.
func (sc *signalingClient) writeLoop(errCh chan<- error) {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case msg := <-sc.sendCh:
			sc.connMu.Lock()
			conn := sc.conn
			sc.connMu.Unlock()
			if conn == nil {
				errCh <- errors.New("no connection")
				return
			}
			if err := wsjson.Write(sc.ctx, conn, msg); err != nil {
				errCh <- fmt.Errorf("write failed: %w", err)
				return
			}
		case <-ticker.C:
			// Send ping to keep connection alive
			sc.connMu.Lock()
			conn := sc.conn
			sc.connMu.Unlock()
			if conn == nil {
				errCh <- errors.New("no connection")
				return
			}
			if err := conn.Ping(sc.ctx); err != nil {
				errCh <- fmt.Errorf("ping failed: %w", err)
				return
			}
		case <-sc.ctx.Done():
			return
		}
	}
}
// Offer sends an SDP offer to a peer.
func (sc *signalingClient) Offer(from, to string, offer *webrtc.SessionDescription) error {
	msg := &rtclib.SignalingMessage{
		Type:  rtclib.MessageTypeOffer,
		From:  from,
		To:    to,
		Offer: offer,
	}
	return sc.send(msg)
}
// Answer sends an SDP answer to a peer.
func (sc *signalingClient) Answer(from, to string, answer *webrtc.SessionDescription) error {
	msg := &rtclib.SignalingMessage{
		Type:   rtclib.MessageTypeAnswer,
		From:   from,
		To:     to,
		Answer: answer,
	}
	return sc.send(msg)
}
// Candidate sends an ICE candidate to a peer.
func (sc *signalingClient) Candidate(from, to string, candidate *webrtc.ICECandidateInit) error {
	msg := &rtclib.SignalingMessage{
		Type:      rtclib.MessageTypeCandidate,
		From:      from,
		To:        to,
		Candidate: candidate,
	}
	return sc.send(msg)
}

View File

@ -0,0 +1,323 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package magicsock
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/netip"
"strings"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/pion/webrtc/v4"
"tailscale.com/rtclib"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// TestSignalingMessageEncoding tests JSON encoding/decoding of signaling
// messages, verifying both the envelope (Type/From/To) and — fix — the
// payload, which previously was never checked after the round trip.
func TestSignalingMessageEncoding(t *testing.T) {
	disco1 := key.NewDisco()
	disco2 := key.NewDisco()
	tests := []struct {
		name string
		msg rtclib.SignalingMessage
	}{
		{
			name: "offer",
			msg: rtclib.SignalingMessage{
				Type: rtclib.MessageTypeOffer,
				From: disco1.Public().String(),
				To: disco2.Public().String(),
				Offer: &webrtc.SessionDescription{Type: webrtc.SDPTypeOffer, SDP: "v=0..."},
			},
		},
		{
			name: "answer",
			msg: rtclib.SignalingMessage{
				Type: rtclib.MessageTypeAnswer,
				From: disco2.Public().String(),
				To: disco1.Public().String(),
				Answer: &webrtc.SessionDescription{Type: webrtc.SDPTypeAnswer, SDP: "v=0..."},
			},
		},
		{
			name: "candidate",
			msg: rtclib.SignalingMessage{
				Type: rtclib.MessageTypeCandidate,
				From: disco1.Public().String(),
				To: disco2.Public().String(),
				Candidate: &webrtc.ICECandidateInit{Candidate: "..."},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Encode
			data, err := json.Marshal(tt.msg)
			if err != nil {
				t.Fatalf("Marshal failed: %v", err)
			}
			// Decode
			var decoded rtclib.SignalingMessage
			if err := json.Unmarshal(data, &decoded); err != nil {
				t.Fatalf("Unmarshal failed: %v", err)
			}
			// Verify envelope
			if decoded.Type != tt.msg.Type {
				t.Errorf("Type mismatch: got %v, want %v", decoded.Type, tt.msg.Type)
			}
			if decoded.From != tt.msg.From {
				t.Errorf("From mismatch: got %v, want %v", decoded.From, tt.msg.From)
			}
			if decoded.To != tt.msg.To {
				t.Errorf("To mismatch: got %v, want %v", decoded.To, tt.msg.To)
			}
			// Verify the payload survives the round trip too.
			switch {
			case tt.msg.Offer != nil:
				if decoded.Offer == nil || decoded.Offer.Type != tt.msg.Offer.Type || decoded.Offer.SDP != tt.msg.Offer.SDP {
					t.Errorf("Offer mismatch: got %+v, want %+v", decoded.Offer, tt.msg.Offer)
				}
			case tt.msg.Answer != nil:
				if decoded.Answer == nil || decoded.Answer.Type != tt.msg.Answer.Type || decoded.Answer.SDP != tt.msg.Answer.SDP {
					t.Errorf("Answer mismatch: got %+v, want %+v", decoded.Answer, tt.msg.Answer)
				}
			case tt.msg.Candidate != nil:
				if decoded.Candidate == nil || decoded.Candidate.Candidate != tt.msg.Candidate.Candidate {
					t.Errorf("Candidate mismatch: got %+v, want %+v", decoded.Candidate, tt.msg.Candidate)
				}
			}
		})
	}
}
// mockSignalHandler is a test implementation of rtclib.SignalHandler that
// only counts callbacks (no logging, unlike testSignalHandler).
type mockSignalHandler struct {
	offerCount int // guarded by mu
	answerCount int // guarded by mu
	candidateCount int // guarded by mu
	mu sync.Mutex
}
// HandleOffer counts one received offer.
func (m *mockSignalHandler) HandleOffer(from, to string, offer *webrtc.SessionDescription) {
	m.mu.Lock()
	m.offerCount++
	m.mu.Unlock()
}
// HandleAnswer counts one received answer.
func (m *mockSignalHandler) HandleAnswer(from, to string, answer *webrtc.SessionDescription) {
	m.mu.Lock()
	m.answerCount++
	m.mu.Unlock()
}
// HandleCandidate counts one received ICE candidate.
func (m *mockSignalHandler) HandleCandidate(from, to string, candidate *webrtc.ICECandidateInit) {
	m.mu.Lock()
	m.candidateCount++
	m.mu.Unlock()
}
// TestSignalingClientReconnect tests reconnection with backoff: the server
// drops every connection immediately, so the client's runLoop should redial
// at least once within the 3s observation window.
// NOTE(review): the fixed 3s sleep makes this one of the slower tests here.
func TestSignalingClientReconnect(t *testing.T) {
	var connectCount int
	var mu sync.Mutex
	// Mock WebSocket server that closes connections after accepting them
	upgrader := websocket.Upgrader{}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		connectCount++
		mu.Unlock()
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		// Immediately close to force reconnection
		conn.Close()
	}))
	defer server.Close()
	// Convert http:// to ws://
	wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
	client := newSignalingClient(wsURL, t.Logf)
	handler := &mockSignalHandler{}
	err := client.Start(handler)
	if err != nil {
		t.Fatalf("Failed to start client: %v", err)
	}
	defer client.Close()
	// Wait for a few reconnection attempts
	time.Sleep(3 * time.Second)
	mu.Lock()
	count := connectCount
	mu.Unlock()
	// Should have attempted to connect multiple times
	if count < 2 {
		t.Errorf("Expected multiple reconnection attempts, got %d", count)
	}
	t.Logf("Reconnected %d times", count)
}
// TestWebRTCMagicIP tests the WebRTC magic IP constant: it must render as
// 127.3.3.41 and must not collide with the DERP magic IP.
func TestWebRTCMagicIP(t *testing.T) {
	if got := tailcfg.WebRTCMagicIPAddr.String(); got != "127.3.3.41" {
		t.Errorf("WebRTC magic IP = %v, want 127.3.3.41", tailcfg.WebRTCMagicIPAddr)
	}
	if tailcfg.WebRTCMagicIPAddr == tailcfg.DerpMagicIPAddr {
		t.Error("WebRTC magic IP should be different from DERP magic IP")
	}
}
// TestWebRTCPathPriority tests path preference logic: a direct UDP path
// should beat a WebRTC path, and a WebRTC path should beat DERP,
// regardless of measured latency.
func TestWebRTCPathPriority(t *testing.T) {
	directV4 := addrQuality{
		epAddr: epAddr{
			ap: netip.MustParseAddrPort("192.168.1.100:41641"),
		},
		latency: 10 * time.Millisecond,
	}
	// Renamed from "webrtc" to avoid shadowing the imported pion webrtc
	// package, which is used elsewhere in this file.
	webrtcAddr := addrQuality{
		epAddr: epAddr{
			ap: netip.AddrPortFrom(tailcfg.WebRTCMagicIPAddr, 12345),
		},
		latency: 50 * time.Millisecond,
	}
	derp := addrQuality{
		epAddr: epAddr{
			ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, 1),
		},
		latency: 100 * time.Millisecond,
	}
	tests := []struct {
		name string
		a, b addrQuality
		want bool // true if a is better than b
	}{
		{
			name: "direct beats WebRTC",
			a:    directV4,
			b:    webrtcAddr,
			want: true,
		},
		{
			name: "WebRTC beats DERP",
			a:    webrtcAddr,
			b:    derp,
			want: true,
		},
		{
			name: "direct beats DERP",
			a:    directV4,
			b:    derp,
			want: true,
		},
		{
			name: "DERP loses to WebRTC",
			a:    derp,
			b:    webrtcAddr,
			want: false,
		},
		{
			name: "WebRTC loses to direct",
			a:    webrtcAddr,
			b:    directV4,
			want: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := betterAddr(tt.a, tt.b)
			if got != tt.want {
				t.Errorf("betterAddr(%v, %v) = %v, want %v", tt.a.ap, tt.b.ap, got, tt.want)
			}
		})
	}
}
// TestWebRTCReadResult tests webrtcReadResult structure: the copyBuf
// closure must copy the packet payload into a caller buffer, and the
// n/src fields must round-trip unchanged.
func TestWebRTCReadResult(t *testing.T) {
	nk := key.NewNode()
	payload := []byte("test packet data")
	res := webrtcReadResult{
		n:   len(payload),
		src: nk.Public(),
		copyBuf: func(dst []byte) int {
			return copy(dst, payload)
		},
	}

	// copyBuf should write the payload into dst and report its length.
	dst := make([]byte, 100)
	got := res.copyBuf(dst)
	if got != len(payload) {
		t.Errorf("copyBuf returned %d, want %d", got, len(payload))
	}
	if string(dst[:got]) != string(payload) {
		t.Errorf("copyBuf data = %q, want %q", dst[:got], payload)
	}

	// The plain fields should carry through unchanged.
	if res.n != len(payload) {
		t.Errorf("result.n = %d, want %d", res.n, len(payload))
	}
	if res.src != nk.Public() {
		t.Errorf("result.src mismatch")
	}
}
// TestDiscoRXPathWebRTC tests the WebRTC disco path constant: it must
// have the expected value and be distinct from the DERP and UDP paths.
func TestDiscoRXPathWebRTC(t *testing.T) {
	if got := discoRXPathWebRTC; got != "WebRTC" {
		t.Errorf("discoRXPathWebRTC = %q, want %q", got, "WebRTC")
	}
	// Distinctness matters: these values label the receive path in logs
	// and metrics, so a collision would make paths indistinguishable.
	if discoRXPathWebRTC == discoRXPathDERP {
		t.Error("WebRTC path should be different from DERP path")
	}
	if discoRXPathWebRTC == discoRXPathUDP {
		t.Error("WebRTC path should be different from UDP path")
	}
}
// TestWebRTCMetrics tests that WebRTC metrics are properly defined.
func TestWebRTCMetrics(t *testing.T) {
	// Evaluate each nil check on the concrete metric type (before any
	// interface conversion) so a typed-nil never reads as non-nil.
	checks := []struct {
		ok  bool
		msg string
	}{
		{metricRecvDataPacketsWebRTC != nil, "metricRecvDataPacketsWebRTC should be initialized"},
		{metricRecvDataBytesWebRTC != nil, "metricRecvDataBytesWebRTC should be initialized"},
		{metricSendDataPacketsWebRTC != nil, "metricSendDataPacketsWebRTC should be initialized"},
		{metricSendDataBytesWebRTC != nil, "metricSendDataBytesWebRTC should be initialized"},
	}
	for _, c := range checks {
		if !c.ok {
			t.Error(c.msg)
		}
	}
	t.Log("WebRTC metrics are properly defined")
}
// TestPathWebRTCConstant tests the PathWebRTC constant: it must have the
// expected value and be distinct from the other path constants.
func TestPathWebRTCConstant(t *testing.T) {
	if got := PathWebRTC; got != "webrtc" {
		t.Errorf("PathWebRTC = %q, want %q", got, "webrtc")
	}
	// These constants identify the transport path externally, so they
	// must never collide with one another.
	if PathWebRTC == PathDERP {
		t.Error("PathWebRTC should be different from PathDERP")
	}
	if PathWebRTC == PathDirectIPv4 {
		t.Error("PathWebRTC should be different from PathDirectIPv4")
	}
}