From 413ba3863269278a9c3622146f6864a2ef2ba01b Mon Sep 17 00:00:00 2001 From: Adriano Sela Aviles Date: Sat, 11 Apr 2026 12:01:09 -0700 Subject: [PATCH] wgengine/magicsock: add webrtc path to magicsock (experimental) --- cmd/tailscale/cli/status.go | 15 +- example-webrtc-server/go.mod | 5 + example-webrtc-server/go.sum | 2 + example-webrtc-server/main.go | 466 +++++++++++ go.mod | 19 +- go.sum | 34 + ipn/ipnstate/ipnstate.go | 13 +- rtclib/signaling.go | 53 ++ tailcfg/tailcfg.go | 8 + wgengine/magicsock/debugknobs.go | 7 + wgengine/magicsock/debugknobs_stubs.go | 3 + wgengine/magicsock/endpoint.go | 79 +- wgengine/magicsock/magicsock.go | 272 ++++++- wgengine/magicsock/webrtc.go | 747 ++++++++++++++++++ wgengine/magicsock/webrtc_base_js.go | 37 + wgengine/magicsock/webrtc_base_native.go | 70 ++ wgengine/magicsock/webrtc_integration_test.go | 435 ++++++++++ wgengine/magicsock/webrtc_signaling.go | 296 +++++++ wgengine/magicsock/webrtc_test.go | 323 ++++++++ 19 files changed, 2858 insertions(+), 26 deletions(-) create mode 100644 example-webrtc-server/go.mod create mode 100644 example-webrtc-server/go.sum create mode 100644 example-webrtc-server/main.go create mode 100644 rtclib/signaling.go create mode 100644 wgengine/magicsock/webrtc.go create mode 100644 wgengine/magicsock/webrtc_base_js.go create mode 100644 wgengine/magicsock/webrtc_base_native.go create mode 100644 wgengine/magicsock/webrtc_integration_test.go create mode 100644 wgengine/magicsock/webrtc_signaling.go create mode 100644 wgengine/magicsock/webrtc_test.go diff --git a/cmd/tailscale/cli/status.go b/cmd/tailscale/cli/status.go index 9ce4debda..8191c7759 100644 --- a/cmd/tailscale/cli/status.go +++ b/cmd/tailscale/cli/status.go @@ -23,6 +23,7 @@ import ( "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/net/netmon" + "tailscale.com/tailcfg" "tailscale.com/util/dnsname" ) @@ -196,7 +197,19 @@ func runStatus(ctx context.Context, args []string) error { if relay != "" && ps.CurAddr == "" && ps.PeerRelay == "" { f("relay %q", relay) } else if ps.CurAddr != "" { - f("direct %s", ps.CurAddr) + // Check if this is a WebRTC connection (address matches WebRTC magic IP) + if strings.HasPrefix(ps.CurAddr, tailcfg.WebRTCMagicIP) { + // Extract the actual remote address from CurAddr, which for a WebRTC path + // is of the form "${WEBRTC_MAGIC_IP}:${DUMMY_PORT} (${REAL_IP_AND_PORT})" + // e.g. "127.3.3.41:12345 (134.209.53.229:37792)". + realRemoteAddr := "" + if idx := strings.Index(ps.CurAddr, " ("); idx > 0 { + realRemoteAddr = ps.CurAddr[idx+2 : len(ps.CurAddr)-1] + } + f("webrtc %s", realRemoteAddr) + } else { + f("direct %s", ps.CurAddr) + } } else if ps.PeerRelay != "" { f("peer-relay %s", ps.PeerRelay) } diff --git a/example-webrtc-server/go.mod b/example-webrtc-server/go.mod new file mode 100644 index 000000000..58c6300ef --- /dev/null +++ b/example-webrtc-server/go.mod @@ -0,0 +1,5 @@ +module github.com/adrianosela/tailscale/example-webrtc-server + +go 1.26 + +require github.com/gorilla/websocket v1.5.3 diff --git a/example-webrtc-server/go.sum b/example-webrtc-server/go.sum new file mode 100644 index 000000000..25a9fc4bb --- /dev/null +++ b/example-webrtc-server/go.sum @@ -0,0 +1,2 @@ +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/example-webrtc-server/main.go b/example-webrtc-server/main.go new file mode 100644 index 000000000..bab7b83d3 --- /dev/null +++ b/example-webrtc-server/main.go @@ -0,0 +1,466 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// example-webrtc-server is a WebRTC signaling server that supports both +// WebSocket (for Tailscale) and HTTP REST (for standard WebRTC clients). +package main + +import ( + "encoding/json" + "flag" + "fmt" + "log" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +// SignalingMessage represents a WebRTC signaling message. +// This format is compatible with both Tailscale and standard WebRTC clients. +type SignalingMessage struct { + Type string `json:"type"` // "offer", "answer", "candidate" + From string `json:"from"` // sender's disco public key (hex) + To string `json:"to"` // recipient's disco public key (hex) + + // For SDP offer/answer (raw JSON for flexibility) + Offer json.RawMessage `json:"offer,omitempty"` + Answer json.RawMessage `json:"answer,omitempty"` + Candidate json.RawMessage `json:"candidate,omitempty"` + + // Legacy fields for HTTP REST clients + SDP string `json:"sdp,omitempty"` // Used by non-Tailscale clients + + Timestamp time.Time `json:"timestamp"` +} + +// Client represents a connected peer (WebSocket or HTTP polling) +type Client struct { + ID string + Conn *websocket.Conn // nil for HTTP clients + LastSeen time.Time +} + +// SignalingServer manages WebRTC signaling between peers +type SignalingServer struct { + mu sync.RWMutex + clients map[string]*Client // Active WebSocket clients + + // Message queue for HTTP polling clients + messages map[string][]SignalingMessage // Key: "to" peer ID + + upgrader websocket.Upgrader + + // Statistics + stats struct { + totalMessages int + wsConnections int + httpPolls int + activeOffers int + completedPairs int + } +} + +// NewSignalingServer creates a new signaling server +func NewSignalingServer() *SignalingServer { + return &SignalingServer{ + clients: make(map[string]*Client), + messages: make(map[string][]SignalingMessage), + upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // Allow all origins (configure as needed) + }, + }, + } +} + +// RouteMessage delivers a message to the destination peer +func (s *SignalingServer) RouteMessage(msg SignalingMessage, clientIP string) error { + msg.Timestamp = time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + // Update stats + s.stats.totalMessages++ + switch msg.Type { + case "offer": + s.stats.activeOffers++ + // Clear old messages from this sender when starting a new session + s.clearOldMessages(msg.From, msg.To) + case "answer": + s.stats.completedPairs++ + if s.stats.activeOffers > 0 { + s.stats.activeOffers-- + } + } + + log.Printf("[%s] Routing %s from %s to %s", clientIP, msg.Type, msg.From, msg.To) + + // Try to deliver to WebSocket client first + if client, ok := s.clients[msg.To]; ok && client.Conn != nil { + // Send directly via WebSocket + if err := client.Conn.WriteJSON(msg); err != nil { + log.Printf("[%s] Failed to send to WebSocket client %s: %v", clientIP, msg.To, err) + // Remove dead connection + delete(s.clients, msg.To) + // Fall through to queue message + } else { + log.Printf("[%s] Delivered %s to WebSocket client %s", clientIP, msg.Type, msg.To) + return nil + } + } + + // Queue for HTTP polling client + s.messages[msg.To] = append(s.messages[msg.To], msg) + log.Printf("[%s] Queued %s for HTTP client %s", clientIP, msg.Type, msg.To) + + return nil +} + +// clearOldMessages removes previous messages between two peers (used when new session starts) +func (s *SignalingServer) clearOldMessages(from, to string) { + if msgs, ok := s.messages[to]; ok { + filtered := make([]SignalingMessage, 0) + for _, msg := range msgs { + if msg.From != from { + filtered = append(filtered, msg) + } + } + if len(filtered) == 0 { + delete(s.messages, to) + } else { + s.messages[to] = filtered + } + } +} + +// GetMessages retrieves queued messages for an HTTP polling client +func (s *SignalingServer) GetMessages(peerID, clientIP string) []SignalingMessage { + s.mu.Lock() + defer s.mu.Unlock() + + s.stats.httpPolls++ + + messages := s.messages[peerID] + if len(messages) == 0 { + return nil + } + + // Return all messages and clear the queue + delete(s.messages, peerID) + + log.Printf("[%s] Delivering %d queued message(s) to HTTP client %s", clientIP, len(messages), peerID) + return messages +} + +// CleanupOldMessages removes stale messages and dead connections +func (s *SignalingServer) CleanupOldMessages(maxAge time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() + + cutoff := time.Now().Add(-maxAge) + cleaned := 0 + + // Clean old messages + for peerID, messages := range s.messages { + filtered := make([]SignalingMessage, 0, len(messages)) + for _, msg := range messages { + if msg.Timestamp.After(cutoff) { + filtered = append(filtered, msg) + } else { + cleaned++ + } + } + + if len(filtered) == 0 { + delete(s.messages, peerID) + } else { + s.messages[peerID] = filtered + } + } + + // Clean inactive clients + for id, client := range s.clients { + if time.Since(client.LastSeen) > maxAge { + if client.Conn != nil { + client.Conn.Close() + } + delete(s.clients, id) + cleaned++ + } + } + + if cleaned > 0 { + log.Printf("Cleaned up %d old messages/connections", cleaned) + } +} + +// GetStats returns current server statistics +func (s *SignalingServer) GetStats() map[string]interface{} { + s.mu.RLock() + defer s.mu.RUnlock() + + queuedMessages := 0 + for _, msgs := range s.messages { + queuedMessages += len(msgs) + } + + return map[string]interface{}{ + "total_messages": s.stats.totalMessages, + "ws_connections": s.stats.wsConnections, + "http_polls": s.stats.httpPolls, + "active_offers": s.stats.activeOffers, + "completed_pairs": s.stats.completedPairs, + "queued_messages": queuedMessages, + "active_ws_clients": len(s.clients), + "active_peer_ids": len(s.messages), + } +} + +// WebSocket Handler (for Tailscale clients) +func (s *SignalingServer) handleWebSocket(w http.ResponseWriter, r *http.Request) { + conn, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("WebSocket upgrade failed: %v", err) + return + } + defer conn.Close() + + var clientID string + + s.mu.Lock() + s.stats.wsConnections++ + s.mu.Unlock() + + defer func() { + s.mu.Lock() + s.stats.wsConnections-- + s.mu.Unlock() + }() + + log.Printf("[%s] New WebSocket connection", r.RemoteAddr) + + // Read and route messages from this WebSocket client + for { + var msg SignalingMessage + if err := conn.ReadJSON(&msg); err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("[%s] WebSocket error: %v", r.RemoteAddr, err) + } + break + } + + // Register client on first message + if clientID == "" { + clientID = msg.From + s.mu.Lock() + s.clients[clientID] = &Client{ + ID: clientID, + Conn: conn, + LastSeen: time.Now(), + } + s.mu.Unlock() + log.Printf("[%s] WebSocket client registered as %s", r.RemoteAddr, clientID) + } + + // Update last seen + s.mu.Lock() + if client, ok := s.clients[clientID]; ok { + client.LastSeen = time.Now() + } + s.mu.Unlock() + + // Route the message to destination + if err := s.RouteMessage(msg, r.RemoteAddr); err != nil { + log.Printf("[%s] Failed to route message: %v", r.RemoteAddr, err) + } + } + + // Cleanup on disconnect + if clientID != "" { + s.mu.Lock() + delete(s.clients, clientID) + s.mu.Unlock() + log.Printf("[%s] WebSocket client %s disconnected", r.RemoteAddr, clientID) + } +} + +// HTTP REST Handlers (for standard WebRTC clients) + +func (s *SignalingServer) handlePostSignal(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var msg SignalingMessage + if err := json.NewDecoder(r.Body).Decode(&msg); err != nil { + http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) + return + } + + // Validate required fields + if msg.From == "" || msg.To == "" || msg.Type == "" { + http.Error(w, "Missing required fields: from, to, type", http.StatusBadRequest) + return + } + + if msg.Type != "offer" && msg.Type != "answer" && msg.Type != "candidate" { + http.Error(w, "Invalid type, must be 'offer', 'answer', or 'candidate'", http.StatusBadRequest) + return + } + + // Convert legacy SDP field to Offer/Answer format for compatibility + if msg.SDP != "" { + sdpJSON := json.RawMessage(fmt.Sprintf(`{"type":"%s","sdp":%q}`, msg.Type, msg.SDP)) + if msg.Type == "offer" { + msg.Offer = sdpJSON + } else if msg.Type == "answer" { + msg.Answer = sdpJSON + } + } + + if err := s.RouteMessage(msg, r.RemoteAddr); err != nil { + http.Error(w, fmt.Sprintf("Failed to route message: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]string{ + "status": "ok", + "message": fmt.Sprintf("Message routed to %s", msg.To), + }) +} + +func (s *SignalingServer) handleGetSignal(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + to := r.URL.Query().Get("to") + if to == "" { + http.Error(w, "Missing required query parameter: to", http.StatusBadRequest) + return + } + + messages := s.GetMessages(to, r.RemoteAddr) + + w.Header().Set("Content-Type", "application/json") + + if len(messages) == 0 { + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(map[string]string{ + "status": "not_found", + "message": fmt.Sprintf("No messages for %s", to), + }) + return + } + + w.WriteHeader(http.StatusOK) + // Return first message (client should poll again for more) + json.NewEncoder(w).Encode(messages[0]) +} + +func (s *SignalingServer) handleStats(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(s.GetStats()) +} + +func (s *SignalingServer) handleHealth(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "status": "healthy", + "service": "webrtc-signaling-server", + "version": "1.0.0", + }) +} + +// CORS middleware +func corsMiddleware(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return + } + + next(w, r) + } +} + +// Logging middleware +func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + next(w, r) + log.Printf("%s %s %s %s", r.RemoteAddr, r.Method, r.URL.Path, time.Since(start)) + } +} + +func main() { + port := flag.Int("port", 8080, "Port to listen on") + tlsCert := flag.String("cert", "", "TLS certificate file (optional, for HTTPS)") + tlsKey := flag.String("key", "", "TLS key file (optional, for HTTPS)") + cleanupInterval := flag.Duration("cleanup", 5*time.Minute, "Interval for cleaning up old messages") + messageMaxAge := flag.Duration("max-age", 10*time.Minute, "Maximum age for messages before cleanup") + flag.Parse() + + server := NewSignalingServer() + + // Start cleanup goroutine + go func() { + ticker := time.NewTicker(*cleanupInterval) + defer ticker.Stop() + + for range ticker.C { + server.CleanupOldMessages(*messageMaxAge) + } + }() + + // Register handlers + // WebSocket endpoint (for Tailscale) + http.HandleFunc("/ws", loggingMiddleware(server.handleWebSocket)) + + // HTTP REST endpoints (for standard WebRTC clients) + http.HandleFunc("/signal", corsMiddleware(loggingMiddleware(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + server.handlePostSignal(w, r) + } else if r.Method == http.MethodGet { + server.handleGetSignal(w, r) + } else { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + }))) + + // Monitoring endpoints + http.HandleFunc("/stats", corsMiddleware(loggingMiddleware(server.handleStats))) + http.HandleFunc("/health", corsMiddleware(loggingMiddleware(server.handleHealth))) + + addr := fmt.Sprintf(":%d", *port) + + log.Printf("Server starting on %s", addr) + log.Printf("WebSocket endpoint: ws://localhost%s/ws", addr) + log.Printf("HTTP REST endpoint: http://localhost%s/signal", addr) + log.Printf("Cleanup interval: %v, Max message age: %v", *cleanupInterval, *messageMaxAge) + log.Println("────────────────────────────────────────────────────────────") + + var err error + if *tlsCert != "" && *tlsKey != "" { + log.Printf("Starting HTTPS server with TLS...") + err = http.ListenAndServeTLS(addr, *tlsCert, *tlsKey, nil) + } else { + log.Printf("Starting HTTP server (use -cert and -key for HTTPS)...") + err = http.ListenAndServe(addr, nil) + } + + if err != nil { + log.Fatal(err) + } +} diff --git a/go.mod b/go.mod index bdd713a30..1efe4eb15 100644 --- a/go.mod +++ b/go.mod @@ -57,6 +57,7 @@ require ( github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806 github.com/google/uuid v1.6.0 github.com/goreleaser/nfpm/v2 v2.33.1 + github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 github.com/hashicorp/go-hclog v1.6.2 github.com/hashicorp/raft v1.7.2 github.com/hashicorp/raft-boltdb/v2 v2.3.1 @@ -79,6 +80,7 @@ require ( github.com/miekg/dns v1.1.58 github.com/mitchellh/go-ps v1.0.0 github.com/peterbourgon/ff/v3 v3.4.0 + github.com/pion/webrtc/v4 v4.0.0 github.com/pires/go-proxyproto v0.8.1 github.com/pkg/errors v0.9.1 github.com/pkg/sftp v1.13.6 @@ -198,7 +200,6 @@ require ( github.com/google/renameio/v2 v2.0.0 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/gorilla/securecookie v1.1.2 // indirect - github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect github.com/gosuri/uitable v0.0.4 // indirect github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect @@ -230,6 +231,21 @@ require ( github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect github.com/pelletier/go-toml v1.9.5 // indirect github.com/peterbourgon/diskv v2.0.1+incompatible // indirect + github.com/pion/datachannel v1.5.9 // indirect + github.com/pion/dtls/v3 v3.0.3 // indirect + github.com/pion/ice/v4 v4.0.2 // indirect + github.com/pion/interceptor v0.1.37 // indirect + github.com/pion/logging v0.2.2 // indirect + github.com/pion/mdns/v2 v2.0.7 // indirect + github.com/pion/randutil v0.1.0 // indirect + github.com/pion/rtcp v1.2.14 // indirect + github.com/pion/rtp v1.8.9 // indirect + github.com/pion/sctp v1.8.33 // indirect + github.com/pion/sdp/v3 v3.0.9 // indirect + github.com/pion/srtp/v3 v3.0.4 // indirect + github.com/pion/stun/v3 v3.0.0 // indirect + github.com/pion/transport/v3 v3.0.7 // indirect + github.com/pion/turn/v4 v4.0.0 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/puzpuzpuz/xsync v1.5.2 // indirect github.com/rtr7/dhcp4 v0.0.0-20220302171438-18c84d089b46 // indirect @@ -239,6 +255,7 @@ require ( github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 // indirect github.com/stacklok/frizbee v0.1.7 // indirect github.com/vishvananda/netlink v1.3.1-0.20240922070040-084abd93d350 // indirect + github.com/wlynxg/anet v0.0.3 // indirect github.com/xen0n/gosmopolitan v1.2.2 // indirect github.com/xlab/treeprint v1.2.0 // indirect github.com/ykadowak/zerologlint v0.1.5 // indirect diff --git a/go.sum b/go.sum index bde9ebb53..da266b16f 100644 --- a/go.sum +++ b/go.sum @@ -948,6 +948,38 @@ github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1H github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= github.com/pierrec/lz4/v4 v4.1.25 h1:kocOqRffaIbU5djlIBr7Wh+cx82C0vtFb0fOurZHqD0= github.com/pierrec/lz4/v4 v4.1.25/go.mod h1:EoQMVJgeeEOMsCqCzqFm2O0cJvljX2nGZjcRIPL34O4= +github.com/pion/datachannel v1.5.9 h1:LpIWAOYPyDrXtU+BW7X0Yt/vGtYxtXQ8ql7dFfYUVZA= +github.com/pion/datachannel v1.5.9/go.mod h1:kDUuk4CU4Uxp82NH4LQZbISULkX/HtzKa4P7ldf9izE= +github.com/pion/dtls/v3 v3.0.3 h1:j5ajZbQwff7Z8k3pE3S+rQ4STvKvXUdKsi/07ka+OWM= +github.com/pion/dtls/v3 v3.0.3/go.mod h1:weOTUyIV4z0bQaVzKe8kpaP17+us3yAuiQsEAG1STMU= +github.com/pion/ice/v4 v4.0.2 h1:1JhBRX8iQLi0+TfcavTjPjI6GO41MFn4CeTBX+Y9h5s= +github.com/pion/ice/v4 v4.0.2/go.mod h1:DCdqyzgtsDNYN6/3U8044j3U7qsJ9KFJC92VnOWHvXg= +github.com/pion/interceptor v0.1.37 h1:aRA8Zpab/wE7/c0O3fh1PqY0AJI3fCSEM5lRWJVorwI= +github.com/pion/interceptor v0.1.37/go.mod h1:JzxbJ4umVTlZAf+/utHzNesY8tmRkM2lVmkS82TTj8Y= +github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= +github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= +github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM= +github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA= +github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= +github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= +github.com/pion/rtcp v1.2.14 h1:KCkGV3vJ+4DAJmvP0vaQShsb0xkRfWkO540Gy102KyE= +github.com/pion/rtcp v1.2.14/go.mod h1:sn6qjxvnwyAkkPzPULIbVqSKI5Dv54Rv7VG0kNxh9L4= +github.com/pion/rtp v1.8.9 h1:E2HX740TZKaqdcPmf4pw6ZZuG8u5RlMMt+l3dxeu6Wk= +github.com/pion/rtp v1.8.9/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU= +github.com/pion/sctp v1.8.33 h1:dSE4wX6uTJBcNm8+YlMg7lw1wqyKHggsP5uKbdj+NZw= +github.com/pion/sctp v1.8.33/go.mod h1:beTnqSzewI53KWoG3nqB282oDMGrhNxBdb+JZnkCwRM= +github.com/pion/sdp/v3 v3.0.9 h1:pX++dCHoHUwq43kuwf3PyJfHlwIj4hXA7Vrifiq0IJY= +github.com/pion/sdp/v3 v3.0.9/go.mod h1:B5xmvENq5IXJimIO4zfp6LAe1fD9N+kFv+V/1lOdz8M= +github.com/pion/srtp/v3 v3.0.4 h1:2Z6vDVxzrX3UHEgrUyIGM4rRouoC7v+NiF1IHtp9B5M= +github.com/pion/srtp/v3 v3.0.4/go.mod h1:1Jx3FwDoxpRaTh1oRV8A/6G1BnFL+QI82eK4ms8EEJQ= +github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw= +github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU= +github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= +github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= +github.com/pion/turn/v4 v4.0.0 h1:qxplo3Rxa9Yg1xXDxxH8xaqcyGUtbHYw4QSCvmFWvhM= +github.com/pion/turn/v4 v4.0.0/go.mod h1:MuPDkm15nYSklKpN8vWJ9W2M0PlyQZqYt1McGuxG7mA= +github.com/pion/webrtc/v4 v4.0.0 h1:x8ec7uJQPP3D1iI8ojPAiTOylPI7Fa7QgqZrhpLyqZ8= +github.com/pion/webrtc/v4 v4.0.0/go.mod h1:SfNn8CcFxR6OUVjLXVslAQ3a3994JhyE3Hw1jAuqEto= github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0= github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4= @@ -1209,6 +1241,8 @@ github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1 github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg= +github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= diff --git a/ipn/ipnstate/ipnstate.go b/ipn/ipnstate/ipnstate.go index 17e6ac870..16a666573 100644 --- a/ipn/ipnstate/ipnstate.go +++ b/ipn/ipnstate/ipnstate.go @@ -651,7 +651,18 @@ table tbody tr:nth-child(even) td { background-color: #f5f5f5; } if ps.Relay != "" && ps.CurAddr == "" { f("relay %s", html.EscapeString(ps.Relay)) } else if ps.CurAddr != "" { - f("direct %s", html.EscapeString(ps.CurAddr)) + // Check if this is a WebRTC connection (magic IP 127.3.3.41) + if strings.HasPrefix(ps.CurAddr, "127.3.3.41:") { + // Extract the actual remote address if present + if idx := strings.Index(ps.CurAddr, " ("); idx > 0 { + remoteAddr := ps.CurAddr[idx+2 : len(ps.CurAddr)-1] // Extract address from " (addr)" + f("webrtc %s", html.EscapeString(remoteAddr)) + } else { + f("webrtc") + } + } else { + f("direct %s", html.EscapeString(ps.CurAddr)) + } } } diff --git a/rtclib/signaling.go b/rtclib/signaling.go new file mode 100644 index 000000000..d91c10c25 --- /dev/null +++ b/rtclib/signaling.go @@ -0,0 +1,53 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package rtclib + +import "github.com/pion/webrtc/v4" + +// Signaling message types. +const ( + MessageTypeOffer = "offer" + MessageTypeAnswer = "answer" + MessageTypeCandidate = "candidate" +) + +// SignalingMessage represents a message exchanged over the signaling channel. +type SignalingMessage struct { + Type string `json:"type"` // "offer", "answer", "candidate" + From string `json:"from"` // sender's disco public key (hex) + To string `json:"to"` // recipient's disco public key (hex) + Offer *webrtc.SessionDescription `json:"offer,omitempty"` + Answer *webrtc.SessionDescription `json:"answer,omitempty"` + Candidate *webrtc.ICECandidateInit `json:"candidate,omitempty"` +} + +// SignalHandler defines callbacks for handling incoming signaling messages. +type SignalHandler interface { + // HandleOffer is called when an offer is received from a peer. + HandleOffer(from, to string, offer *webrtc.SessionDescription) + + // HandleAnswer is called when an answer is received from a peer. + HandleAnswer(from, to string, answer *webrtc.SessionDescription) + + // HandleCandidate is called when an ICE candidate is received from a peer. + HandleCandidate(from, to string, candidate *webrtc.ICECandidateInit) +} + +// Signaller defines the interface for WebRTC signaling implementations. +type Signaller interface { + // Start begins the signaling connection with the provided handler. + Start(handler SignalHandler) error + + // Offer sends an SDP offer to a peer. + Offer(from, to string, offer *webrtc.SessionDescription) error + + // Answer sends an SDP answer to a peer. + Answer(from, to string, answer *webrtc.SessionDescription) error + + // Candidate sends an ICE candidate to a peer. + Candidate(from, to string, candidate *webrtc.ICECandidateInit) error + + // Close shuts down the signaling connection. + Close() error +} diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 0811ecc9f..a62e0a344 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -3292,6 +3292,14 @@ const DerpMagicIP = "127.3.3.40" var DerpMagicIPAddr = netip.MustParseAddr(DerpMagicIP) +// WebRTCMagicIP is a fake WireGuard endpoint IP address that means +// to use WebRTC data channel for packet transmission. +// +// Mnemonic: 127.3.3.41 is one above DerpMagicIP for WebRTC. +const WebRTCMagicIP = "127.3.3.41" + +var WebRTCMagicIPAddr = netip.MustParseAddr(WebRTCMagicIP) + // EarlyNoise is the early payload that's sent over Noise but before the HTTP/2 // handshake when connecting to the coordination server. // diff --git a/wgengine/magicsock/debugknobs.go b/wgengine/magicsock/debugknobs.go index 580d954c0..230f679ed 100644 --- a/wgengine/magicsock/debugknobs.go +++ b/wgengine/magicsock/debugknobs.go @@ -66,6 +66,13 @@ var ( // suppressing/dropping inbound/outbound [disco.Ping] messages, forcing // all peer communication over DERP or peer relay. debugNeverDirectUDP = envknob.RegisterBool("TS_DEBUG_NEVER_DIRECT_UDP") + // debugWebRTCSignalingURL sets the WebRTC signaling server URL for + // establishing WebRTC peer connections. When set, magicsock will attempt + // to use WebRTC as an additional path for peer communication. + debugWebRTCSignalingURL = envknob.RegisterString("TS_DEBUG_WEBRTC_SIGNALING_URL") + // debugAlwaysWebRTC forces all peer communication over WebRTC by + // suppressing disco pings to direct UDP and DERP addresses. + debugAlwaysWebRTC = envknob.RegisterBool("TS_DEBUG_ALWAYS_USE_WEBRTC") // Hey you! Adding a new debugknob? Make sure to stub it out in the // debugknobs_stubs.go file too. ) diff --git a/wgengine/magicsock/debugknobs_stubs.go b/wgengine/magicsock/debugknobs_stubs.go index c156ff8a7..ec8a66fcb 100644 --- a/wgengine/magicsock/debugknobs_stubs.go +++ b/wgengine/magicsock/debugknobs_stubs.go @@ -7,6 +7,7 @@ package magicsock import ( "net/netip" + "os" "tailscale.com/types/opt" ) @@ -32,3 +33,5 @@ func inTest() bool { return false } func debugPeerMap() bool { return false } func pretendpoints() []netip.AddrPort { return []netip.AddrPort{} } func debugNeverDirectUDP() bool { return false } +func debugWebRTCSignalingURL() string { return os.Getenv("TS_DEBUG_WEBRTC_SIGNALING_URL") } +func debugAlwaysWebRTC() bool { return false } diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index b8d3b96be..8cf00c9f7 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -1076,7 +1076,13 @@ func (de *endpoint) send(buffs [][]byte, offset int) error { } } var err error - if udpAddr.ap.IsValid() { + // Check if this is a WebRTC address and route accordingly + if udpAddr.ap.IsValid() && udpAddr.ap.Addr() == tailcfg.WebRTCMagicIPAddr { + // Pack all buffs into one SCTP message. See sendWebRTCBatch for why. + if err = de.c.sendWebRTCBatch(de.publicKey, buffs, offset); err != nil { + return err + } + } else if udpAddr.ap.IsValid() { _, err = de.c.sendUDPBatch(udpAddr, buffs, offset) // If the error is known to indicate that the endpoint is no longer @@ -1295,6 +1301,9 @@ func (de *endpoint) startDiscoPingLocked(ep epAddr, now mono.Time, purpose disco if debugNeverDirectUDP() && !ep.vni.IsSet() && ep.ap.Addr() != tailcfg.DerpMagicIPAddr { return } + if debugAlwaysWebRTC() && !ep.vni.IsSet() && ep.ap.Addr() != tailcfg.WebRTCMagicIPAddr { + return + } epDisco := de.disco.Load() if epDisco == nil { return @@ -1815,10 +1824,13 @@ type epAddr struct { vni packet.VirtualNetworkID // vni.IsSet() indicates if this [epAddr] involves a Geneve header } -// isDirect returns true if e.ap is valid and not tailcfg.DerpMagicIPAddr, +// isDirect returns true if e.ap is valid and not tailcfg.DerpMagicIPAddr or WebRTCMagicIPAddr, // and a VNI is not set. func (e epAddr) isDirect() bool { - return e.ap.IsValid() && e.ap.Addr() != tailcfg.DerpMagicIPAddr && !e.vni.IsSet() + return e.ap.IsValid() && + e.ap.Addr() != tailcfg.DerpMagicIPAddr && + e.ap.Addr() != tailcfg.WebRTCMagicIPAddr && + !e.vni.IsSet() } func (e epAddr) String() string { @@ -1871,6 +1883,28 @@ func betterAddr(a, b addrQuality) bool { return false } + // WebRTC path priority: Direct UDP > WebRTC > Peer Relay/DERP + aIsWebRTC := a.ap.Addr() == tailcfg.WebRTCMagicIPAddr + bIsWebRTC := b.ap.Addr() == tailcfg.WebRTCMagicIPAddr + aIsDERP := a.ap.Addr() == tailcfg.DerpMagicIPAddr + bIsDERP := b.ap.Addr() == tailcfg.DerpMagicIPAddr + + // Direct paths beat WebRTC + if a.isDirect() && bIsWebRTC { + return true + } + if b.isDirect() && aIsWebRTC { + return false + } + + // WebRTC beats DERP and relay (VNI) + if aIsWebRTC && (bIsDERP || b.vni.IsSet()) { + return true + } + if bIsWebRTC && (aIsDERP || a.vni.IsSet()) { + return false + } + // Each address starts with a set of points (from 0 to 100) that // represents how much faster they are than the highest-latency // endpoint. For example, if a has latency 200ms and b has latency @@ -1891,19 +1925,26 @@ func betterAddr(a, b addrQuality) bool { // addresses, and prefer link-local unicast addresses over other types // of private IP addresses since it's definitionally more likely that // they'll be on the same network segment than a general private IP. - if a.ap.Addr().IsLoopback() { - aPoints += 50 - } else if a.ap.Addr().IsLinkLocalUnicast() { - aPoints += 30 - } else if a.ap.Addr().IsPrivate() { - aPoints += 20 + // + // Exclude magic IPs (DERP, WebRTC) from these bonuses as they're not + // real network paths. + if !aIsDERP && !aIsWebRTC { + if a.ap.Addr().IsLoopback() { + aPoints += 50 + } else if a.ap.Addr().IsLinkLocalUnicast() { + aPoints += 30 + } else if a.ap.Addr().IsPrivate() { + aPoints += 20 + } } - if b.ap.Addr().IsLoopback() { - bPoints += 50 - } else if b.ap.Addr().IsLinkLocalUnicast() { - bPoints += 30 - } else if b.ap.Addr().IsPrivate() { - bPoints += 20 + if !bIsDERP && !bIsWebRTC { + if b.ap.Addr().IsLoopback() { + bPoints += 50 + } else if b.ap.Addr().IsLinkLocalUnicast() { + bPoints += 30 + } else if b.ap.Addr().IsPrivate() { + bPoints += 20 + } } // Prefer IPv6 for being a bit more robust, as long as @@ -2035,6 +2076,14 @@ func (de *endpoint) populatePeerStatus(ps *ipnstate.PeerStatus) { ps.PeerRelay = udpAddr.String() } else { ps.CurAddr = udpAddr.String() + // If this is a WebRTC connection, append the actual remote address + if udpAddr.ap.Addr() == tailcfg.WebRTCMagicIPAddr && de.c.webrtcMgr != nil { + if disco := de.disco.Load(); disco != nil { + if remoteAddr := de.c.webrtcMgr.getRemoteAddr(disco.key); remoteAddr.IsValid() { + ps.CurAddr = fmt.Sprintf("%s (%s)", ps.CurAddr, remoteAddr) + } + } + } } } } diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index f13e31554..dbbdf061c 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -94,6 +94,7 @@ type Path string const ( PathDirectIPv4 Path = "direct_ipv4" PathDirectIPv6 Path = "direct_ipv6" + PathWebRTC Path = "webrtc" PathDERP Path = "derp" PathPeerRelayIPv4 Path = "peer_relay_ipv4" PathPeerRelayIPv6 Path = "peer_relay_ipv6" @@ -109,6 +110,14 @@ type pathLabel struct { Path Path } +// webrtcReadResult is the result of reading a packet from a WebRTC data channel. +// It is similar to derpReadResult but for WebRTC connections. +type webrtcReadResult struct { + n int // length of data in buf + src key.NodePublic // sender's node public key + buf []byte // packet data; nil signals the receiver to ignore this message +} + // metrics in wgengine contains the usermetrics counters for magicsock, it // is however a bit special. All them metrics are labeled, but looking up // the metric everytime we need to record it has an overhead, and includes @@ -119,6 +128,7 @@ type metrics struct { // labeled by the path the packet took. inboundPacketsIPv4Total expvar.Int inboundPacketsIPv6Total expvar.Int + inboundPacketsWebRTCTotal expvar.Int inboundPacketsDERPTotal expvar.Int inboundPacketsPeerRelayIPv4Total expvar.Int inboundPacketsPeerRelayIPv6Total expvar.Int @@ -127,6 +137,7 @@ type metrics struct { // labeled by the path the packet took. inboundBytesIPv4Total expvar.Int inboundBytesIPv6Total expvar.Int + inboundBytesWebRTCTotal expvar.Int inboundBytesDERPTotal expvar.Int inboundBytesPeerRelayIPv4Total expvar.Int inboundBytesPeerRelayIPv6Total expvar.Int @@ -135,6 +146,7 @@ type metrics struct { // labeled by the path the packet took. outboundPacketsIPv4Total expvar.Int outboundPacketsIPv6Total expvar.Int + outboundPacketsWebRTCTotal expvar.Int outboundPacketsDERPTotal expvar.Int outboundPacketsPeerRelayIPv4Total expvar.Int outboundPacketsPeerRelayIPv6Total expvar.Int @@ -143,6 +155,7 @@ type metrics struct { // labeled by the path the packet took. outboundBytesIPv4Total expvar.Int outboundBytesIPv6Total expvar.Int + outboundBytesWebRTCTotal expvar.Int outboundBytesDERPTotal expvar.Int outboundBytesPeerRelayIPv4Total expvar.Int outboundBytesPeerRelayIPv6Total expvar.Int @@ -211,6 +224,10 @@ type Conn struct { // It must have buffer size > 0; see issue 3736. derpRecvCh chan derpReadResult + // webrtcRecvCh is used by receiveWebRTC to read WebRTC messages. + // It must have buffer size > 0, similar to derpRecvCh. + webrtcRecvCh chan webrtcReadResult + // bind is the wireguard-go conn.Bind for Conn. bind *connBind @@ -343,6 +360,10 @@ type Conn struct { // [tailscale.com/net/udprelay.Server] endpoints. relayManager relayManager + // webrtcMgr manages WebRTC connections for peers. + // May be nil if WebRTC is disabled (no TS_DEBUG_WEBRTC_SIGNALING_URL). + webrtcMgr *webrtcManager + // discoInfo is the state for an active peer DiscoKey. discoInfo map[key.DiscoPublic]*discoInfo @@ -575,7 +596,8 @@ func newConn(logf logger.Logf) *Conn { discoPrivate := key.NewDisco() c := &Conn{ logf: logf, - derpRecvCh: make(chan derpReadResult, 1), // must be buffered, see issue 3736 + derpRecvCh: make(chan derpReadResult, 1), // must be buffered, see issue 3736 + webrtcRecvCh: make(chan webrtcReadResult, 64), // must be buffered, similar to derpRecvCh derpStarted: make(chan struct{}), peerLastDerp: make(map[key.NodePublic]int), peerMap: newPeerMap(), @@ -727,6 +749,16 @@ func NewConn(opts Options) (*Conn, error) { } c.logf("magicsock: disco key = %v", c.discoAtomic.Short()) + + // Initialize WebRTC manager if signaling server URL is set + if signalingURL := debugWebRTCSignalingURL(); signalingURL != "" { + c.logf("magicsock: initializing WebRTC with signaling server %s", signalingURL) + c.webrtcMgr = newWebRTCManager(c, signalingURL) + if c.webrtcMgr == nil { + c.logf("magicsock: failed to initialize WebRTC manager") + } + } + return c, nil } @@ -736,6 +768,7 @@ func NewConn(opts Options) (*Conn, error) { func registerMetrics(reg *usermetric.Registry) *metrics { pathDirectV4 := pathLabel{Path: PathDirectIPv4} pathDirectV6 := pathLabel{Path: PathDirectIPv6} + pathWebRTC := pathLabel{Path: PathWebRTC} pathDERP := pathLabel{Path: PathDERP} pathPeerRelayV4 := pathLabel{Path: PathPeerRelayIPv4} pathPeerRelayV6 := pathLabel{Path: PathPeerRelayIPv6} @@ -770,21 +803,25 @@ func registerMetrics(reg *usermetric.Registry) *metrics { // Map clientmetrics to the usermetric counters. metricRecvDataPacketsIPv4.Register(&m.inboundPacketsIPv4Total) metricRecvDataPacketsIPv6.Register(&m.inboundPacketsIPv6Total) + metricRecvDataPacketsWebRTC.Register(&m.inboundPacketsWebRTCTotal) metricRecvDataPacketsDERP.Register(&m.inboundPacketsDERPTotal) metricRecvDataPacketsPeerRelayIPv4.Register(&m.inboundPacketsPeerRelayIPv4Total) metricRecvDataPacketsPeerRelayIPv6.Register(&m.inboundPacketsPeerRelayIPv6Total) metricRecvDataBytesIPv4.Register(&m.inboundBytesIPv4Total) metricRecvDataBytesIPv6.Register(&m.inboundBytesIPv6Total) + metricRecvDataBytesWebRTC.Register(&m.inboundBytesWebRTCTotal) metricRecvDataBytesDERP.Register(&m.inboundBytesDERPTotal) metricRecvDataBytesPeerRelayIPv4.Register(&m.inboundBytesPeerRelayIPv4Total) metricRecvDataBytesPeerRelayIPv6.Register(&m.inboundBytesPeerRelayIPv6Total) metricSendDataPacketsIPv4.Register(&m.outboundPacketsIPv4Total) metricSendDataPacketsIPv6.Register(&m.outboundPacketsIPv6Total) + metricSendDataPacketsWebRTC.Register(&m.outboundPacketsWebRTCTotal) metricSendDataPacketsDERP.Register(&m.outboundPacketsDERPTotal) metricSendDataPacketsPeerRelayIPv4.Register(&m.outboundPacketsPeerRelayIPv4Total) metricSendDataPacketsPeerRelayIPv6.Register(&m.outboundPacketsPeerRelayIPv6Total) metricSendDataBytesIPv4.Register(&m.outboundBytesIPv4Total) metricSendDataBytesIPv6.Register(&m.outboundBytesIPv6Total) + metricSendDataBytesWebRTC.Register(&m.outboundBytesWebRTCTotal) metricSendDataBytesDERP.Register(&m.outboundBytesDERPTotal) metricSendDataBytesPeerRelayIPv4.Register(&m.outboundBytesPeerRelayIPv4Total) metricSendDataBytesPeerRelayIPv6.Register(&m.outboundBytesPeerRelayIPv6Total) @@ -796,24 +833,28 @@ func registerMetrics(reg *usermetric.Registry) *metrics { inboundPacketsTotal.Set(pathDirectV4, &m.inboundPacketsIPv4Total) inboundPacketsTotal.Set(pathDirectV6, &m.inboundPacketsIPv6Total) + inboundPacketsTotal.Set(pathWebRTC, &m.inboundPacketsWebRTCTotal) inboundPacketsTotal.Set(pathDERP, &m.inboundPacketsDERPTotal) inboundPacketsTotal.Set(pathPeerRelayV4, &m.inboundPacketsPeerRelayIPv4Total) inboundPacketsTotal.Set(pathPeerRelayV6, &m.inboundPacketsPeerRelayIPv6Total) inboundBytesTotal.Set(pathDirectV4, &m.inboundBytesIPv4Total) inboundBytesTotal.Set(pathDirectV6, &m.inboundBytesIPv6Total) + inboundBytesTotal.Set(pathWebRTC, &m.inboundBytesWebRTCTotal) inboundBytesTotal.Set(pathDERP, &m.inboundBytesDERPTotal) inboundBytesTotal.Set(pathPeerRelayV4, &m.inboundBytesPeerRelayIPv4Total) inboundBytesTotal.Set(pathPeerRelayV6, &m.inboundBytesPeerRelayIPv6Total) outboundPacketsTotal.Set(pathDirectV4, &m.outboundPacketsIPv4Total) outboundPacketsTotal.Set(pathDirectV6, &m.outboundPacketsIPv6Total) + outboundPacketsTotal.Set(pathWebRTC, &m.outboundPacketsWebRTCTotal) outboundPacketsTotal.Set(pathDERP, &m.outboundPacketsDERPTotal) outboundPacketsTotal.Set(pathPeerRelayV4, &m.outboundPacketsPeerRelayIPv4Total) outboundPacketsTotal.Set(pathPeerRelayV6, &m.outboundPacketsPeerRelayIPv6Total) outboundBytesTotal.Set(pathDirectV4, &m.outboundBytesIPv4Total) outboundBytesTotal.Set(pathDirectV6, &m.outboundBytesIPv6Total) + outboundBytesTotal.Set(pathWebRTC, &m.outboundBytesWebRTCTotal) outboundBytesTotal.Set(pathDERP, &m.outboundBytesDERPTotal) outboundBytesTotal.Set(pathPeerRelayV4, &m.outboundBytesPeerRelayIPv4Total) outboundBytesTotal.Set(pathPeerRelayV6, &m.outboundBytesPeerRelayIPv6Total) @@ -828,21 +869,25 @@ func registerMetrics(reg *usermetric.Registry) *metrics { func deregisterMetrics() { metricRecvDataPacketsIPv4.UnregisterAll() metricRecvDataPacketsIPv6.UnregisterAll() + metricRecvDataPacketsWebRTC.UnregisterAll() metricRecvDataPacketsDERP.UnregisterAll() metricRecvDataPacketsPeerRelayIPv4.UnregisterAll() metricRecvDataPacketsPeerRelayIPv6.UnregisterAll() metricRecvDataBytesIPv4.UnregisterAll() metricRecvDataBytesIPv6.UnregisterAll() + metricRecvDataBytesWebRTC.UnregisterAll() metricRecvDataBytesDERP.UnregisterAll() metricRecvDataBytesPeerRelayIPv4.UnregisterAll() metricRecvDataBytesPeerRelayIPv6.UnregisterAll() metricSendDataPacketsIPv4.UnregisterAll() metricSendDataPacketsIPv6.UnregisterAll() + metricSendDataPacketsWebRTC.UnregisterAll() metricSendDataPacketsDERP.UnregisterAll() metricSendDataPacketsPeerRelayIPv4.UnregisterAll() metricSendDataPacketsPeerRelayIPv6.UnregisterAll() metricSendDataBytesIPv4.UnregisterAll() metricSendDataBytesIPv6.UnregisterAll() + metricSendDataBytesWebRTC.UnregisterAll() metricSendDataBytesDERP.UnregisterAll() metricSendDataBytesPeerRelayIPv4.UnregisterAll() metricSendDataBytesPeerRelayIPv6.UnregisterAll() @@ -1630,6 +1675,9 @@ func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error) // IPv6 address when the local machine doesn't have IPv6 support // returns (false, nil); it's not an error, but nothing was sent. func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte, isDisco bool, isGeneveEncap bool) (sent bool, err error) { + if addr.Addr() == tailcfg.WebRTCMagicIPAddr { + return c.sendWebRTC(addr, pubKey, b) + } if addr.Addr() != tailcfg.DerpMagicIPAddr { return c.sendUDP(addr, b, isDisco, isGeneveEncap) } @@ -1671,6 +1719,170 @@ func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte, is return false, errDropDerpPacket } +// webrtcBatchMagic is the first byte of a batched WebRTC SCTP message. +// WireGuard packets start with 0x01–0x04, disco packets start with 0x54 ('T'), +// so 0xBA is unambiguous. +const webrtcBatchMagic = byte(0xBA) + +// sendWebRTCBatch packs all buffs into a single SCTP message and sends it. +// Batching is the critical throughput optimization: the per-packet WebRTC +// path calls rwc.Write once per WireGuard packet, producing one SCTP message, +// one DTLS record, and one UDP send per packet. Packing N packets into one +// SCTP message reduces that to a single write — the same advantage +// sendUDPBatch (sendmmsg) gives the regular UDP path. +// +// Wire format for N>1: [0xBA magic][2-byte BE len][packet]...[2-byte BE len][packet] +// Single packet: sent as-is with no framing overhead. +func (c *Conn) sendWebRTCBatch(pubKey key.NodePublic, buffs [][]byte, offset int) error { + if c.webrtcMgr == nil { + return nil + } + + // Resolve endpoint and disco key once for the whole batch. + c.mu.Lock() + ep, ok := c.peerMap.endpointForNodeKey(pubKey) + c.mu.Unlock() + if !ok || ep == nil { + return nil + } + disco := ep.disco.Load() + if disco == nil { + return nil + } + + if len(buffs) == 1 { + // Fast path: single packet, no framing overhead. + b := buffs[0][offset:] + if err := c.webrtcMgr.sendPacket(disco.key, b); err != nil { + return err + } + c.metrics.outboundPacketsWebRTCTotal.Add(1) + c.metrics.outboundBytesWebRTCTotal.Add(int64(len(b))) + return nil + } + + // Multi-packet batch path. + size := 1 // magic byte + for _, b := range buffs { + size += 2 + len(b[offset:]) + } + batch := make([]byte, size) + batch[0] = webrtcBatchMagic + pos := 1 + var totalBytes int64 + for _, b := range buffs { + pkt := b[offset:] + binary.BigEndian.PutUint16(batch[pos:], uint16(len(pkt))) + pos += 2 + copy(batch[pos:], pkt) + pos += len(pkt) + totalBytes += int64(len(pkt)) + } + if err := c.webrtcMgr.sendPacket(disco.key, batch); err != nil { + return err + } + c.metrics.outboundPacketsWebRTCTotal.Add(int64(len(buffs))) + c.metrics.outboundBytesWebRTCTotal.Add(totalBytes) + return nil +} + +// sendWebRTC sends a packet over WebRTC data channel. +func (c *Conn) sendWebRTC(addr netip.AddrPort, pubKey key.NodePublic, b []byte) (sent bool, err error) { + if c.webrtcMgr == nil { + return false, nil + } + + // Find the endpoint by public key + c.mu.Lock() + ep, ok := c.peerMap.endpointForNodeKey(pubKey) + c.mu.Unlock() + + if !ok { + return false, nil + } + + // Get the disco key for this endpoint + disco := ep.disco.Load() + if disco == nil { + return false, nil + } + + // Send via WebRTC manager + if err := c.webrtcMgr.sendPacket(disco.key, b); err != nil { + return false, err + } + + // Update metrics + c.metrics.outboundPacketsWebRTCTotal.Add(1) + c.metrics.outboundBytesWebRTCTotal.Add(int64(len(b))) + + return true, nil +} + +// receiveWebRTC handles packets received from WebRTC data channels. +// This is called by webrtcManager when data arrives on a data channel. +// receiveWebRTC is called by webrtcManager when data arrives on a WebRTC data channel. +// It queues the packet for processing by wireguard-go through the webrtcRecvCh channel. +func (c *Conn) receiveWebRTC(b []byte, srcNodeKey key.NodePublic) { + // Copy into a fresh slice: b belongs to the reader goroutine's reusable + // buffer which will be overwritten on the next Read call. + pkt := make([]byte, len(b)) + copy(pkt, b) + + select { + case c.webrtcRecvCh <- webrtcReadResult{n: len(pkt), src: srcNodeKey, buf: pkt}: + case <-c.connCtx.Done(): + default: + c.logf("webrtc: dropped packet from %v, receive channel full", srcNodeKey.ShortString()) + } +} + +// processWebRTCReadResult processes a WebRTC packet received from the webrtcRecvCh. +// It's similar to processDERPReadResult but for WebRTC packets. +func (c *Conn) processWebRTCReadResult(wr webrtcReadResult, b []byte) (n int, ep *endpoint) { + if wr.buf == nil { + return 0, nil + } + + n = wr.n + ncopy := copy(b, wr.buf[:n]) + if ncopy != n { + err := fmt.Errorf("received WebRTC packet of length %d that's too big for WireGuard buf size %d", n, ncopy) + c.logf("magicsock: %v", err) + return 0, nil + } + + srcAddr := epAddr{ap: netip.AddrPortFrom(tailcfg.WebRTCMagicIPAddr, 12345)} + + // Check if this looks like a disco packet + pt, isGeneveEncap := packetLooksLike(b[:n]) + if pt == packetLooksLikeDisco && !isGeneveEncap { + c.handleDiscoMessage(b[:n], srcAddr, false, wr.src, discoRXPathWebRTC) + return 0, nil + } + + // Find the endpoint by node key + var ok bool + c.mu.Lock() + ep, ok = c.peerMap.endpointForNodeKey(wr.src) + c.mu.Unlock() + + if !ok { + // We don't know anything about this node key + return 0, nil + } + + ep.noteRecvActivity(srcAddr, mono.Now()) + if update := c.connCounter.Load(); update != nil { + update(0, netip.AddrPortFrom(ep.nodeAddr, 0), srcAddr.ap, 1, n, true) + } + + c.metrics.inboundPacketsWebRTCTotal.Add(1) + c.metrics.inboundBytesWebRTCTotal.Add(int64(n)) + + return n, ep +} + type receiveBatch struct { msgs []ipv6.Message } @@ -2005,7 +2217,7 @@ func (c *Conn) sendDiscoMessage(dst epAddr, dstKey key.NodePublic, dstDisco key. if !dstKey.IsZero() { node = dstKey.ShortString() } - c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v\n", c.discoAtomic.Short(), dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m), len(pkt)) + c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v\n", c.discoAtomic.Short(), dstDisco.ShortString(), node, pathStr(dst.String()), disco.MessageSummary(m), len(pkt)) } if isDERP { metricSentDiscoDERP.Add(1) @@ -2045,6 +2257,7 @@ type discoRXPath string const ( discoRXPathUDP discoRXPath = "UDP socket" discoRXPathDERP discoRXPath = "DERP" + discoRXPathWebRTC discoRXPath = "WebRTC" discoRXPathRawSocket discoRXPath = "raw socket" ) @@ -2356,13 +2569,13 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake if isVia { c.dlogf("[v1] magicsock: disco: %v<-%v via %v (%v, %v) got call-me-maybe-via, %d endpoints", c.discoAtomic.Short(), epDisco.short, via.ServerDisco.ShortString(), - ep.publicKey.ShortString(), derpStr(src.String()), + ep.publicKey.ShortString(), pathStr(src.String()), len(via.AddrPorts)) c.relayManager.handleCallMeMaybeVia(ep, lastBest, lastBestIsTrusted, via) } else { c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", c.discoAtomic.Short(), epDisco.short, - ep.publicKey.ShortString(), derpStr(src.String()), + ep.publicKey.ShortString(), pathStr(src.String()), len(cmm.MyNumber)) go ep.handleCallMeMaybe(cmm) } @@ -2408,7 +2621,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake if isResp { c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s, %d endpoints", c.discoAtomic.Short(), epDisco.short, - ep.publicKey.ShortString(), derpStr(src.String()), + ep.publicKey.ShortString(), pathStr(src.String()), msgType, len(resp.AddrPorts)) c.relayManager.handleRxDiscoMsg(c, resp, nodeKey, di.discoKey, src) @@ -2422,7 +2635,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake } else { c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s disco[0]=%v disco[1]=%v", c.discoAtomic.Short(), epDisco.short, - ep.publicKey.ShortString(), derpStr(src.String()), + ep.publicKey.ShortString(), pathStr(src.String()), msgType, req.ClientDisco[0].ShortString(), req.ClientDisco[1].ShortString()) } @@ -3107,6 +3320,11 @@ func (c *Conn) updateNodes(self tailcfg.NodeView, peers []tailcfg.NodeView) (pee } ep.updateFromNode(n, flags.heartbeatDisabled, flags.probeUDPLifetimeOn) c.peerMap.upsertEndpoint(ep, oldDiscoKey) // maybe update discokey mappings in peerMap + + // Start WebRTC connection if not already started + if c.webrtcMgr != nil && !n.DiscoKey().IsZero() { + c.webrtcMgr.startConnection(ep) + } continue } @@ -3170,6 +3388,11 @@ func (c *Conn) updateNodes(self tailcfg.NodeView, peers []tailcfg.NodeView) (pee ep.updateFromNode(n, flags.heartbeatDisabled, flags.probeUDPLifetimeOn) c.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) + + // Start WebRTC connection to this peer if WebRTC is enabled + if c.webrtcMgr != nil && !n.DiscoKey().IsZero() { + c.webrtcMgr.startConnection(ep) + } } // If the set of nodes changed since the last SetNetworkMap, the @@ -3287,9 +3510,9 @@ func (c *connBind) Open(ignoredPort uint16) ([]conn.ReceiveFunc, uint16, error) return nil, 0, errors.New("magicsock: connBind already open") } c.closed = false - fns := []conn.ReceiveFunc{c.receiveIPv4(), c.receiveIPv6(), c.receiveDERP} + fns := []conn.ReceiveFunc{c.receiveIPv4(), c.receiveIPv6(), c.receiveDERP, c.receiveWebRTC} if runtime.GOOS == "js" { - fns = []conn.ReceiveFunc{c.receiveDERP} + fns = []conn.ReceiveFunc{c.receiveDERP, c.receiveWebRTC} } // TODO: Combine receiveIPv4 and receiveIPv6 and receiveIP into a single // closure that closes over a *RebindingUDPConn? @@ -3366,6 +3589,14 @@ func (c *Conn) Close() error { ep.stopAndReset() }) + // Close WebRTC manager if initialized + if c.webrtcMgr != nil { + c.webrtcMgr.close() + c.webrtcMgr = nil + } + + close(c.webrtcRecvCh) + c.closed = true c.connCtxCancel() c.closeAllDerpLocked("conn-close") @@ -3920,9 +4151,20 @@ func trySetUDPSocketOptions(pconn nettype.PacketConn, logf logger.Logf) { } } +// pathStr formats endpoint addresses for display, replacing magic IPs with readable names. +// It replaces DERP IPs with "derp-" and WebRTC IPs with "webrtc-". +func pathStr(s string) string { + s = derpStr(s) + s = webrtcStr(s) + return s +} + // derpStr replaces DERP IPs in s with "derp-". func derpStr(s string) string { return strings.ReplaceAll(s, "127.3.3.40:", "derp-") } +// webrtcStr replaces WebRTC IPs in s with "webrtc-". +func webrtcStr(s string) string { return strings.ReplaceAll(s, "127.3.3.41:", "webrtc-") } + // epAddrEndpointCache is a mutex-free single-element cache, mapping from // a single [epAddr] to a single [*endpoint]. type epAddrEndpointCache struct { @@ -4003,11 +4245,13 @@ var ( metricRecvDataPacketsDERP = clientmetric.NewAggregateCounter("magicsock_recv_data_derp") metricRecvDataPacketsIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_ipv4") metricRecvDataPacketsIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_ipv6") + metricRecvDataPacketsWebRTC = clientmetric.NewAggregateCounter("magicsock_recv_data_webrtc") metricRecvDataPacketsPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_peer_relay_ipv4") metricRecvDataPacketsPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_peer_relay_ipv6") metricSendDataPacketsDERP = clientmetric.NewAggregateCounter("magicsock_send_data_derp") metricSendDataPacketsIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_ipv4") metricSendDataPacketsIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_ipv6") + metricSendDataPacketsWebRTC = clientmetric.NewAggregateCounter("magicsock_send_data_webrtc") metricSendDataPacketsPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_peer_relay_ipv4") metricSendDataPacketsPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_peer_relay_ipv6") @@ -4015,11 +4259,13 @@ var ( metricRecvDataBytesDERP = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_derp") metricRecvDataBytesIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_ipv4") metricRecvDataBytesIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_ipv6") + metricRecvDataBytesWebRTC = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_webrtc") metricRecvDataBytesPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_peer_relay_ipv4") metricRecvDataBytesPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_peer_relay_ipv6") metricSendDataBytesDERP = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_derp") metricSendDataBytesIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_ipv4") metricSendDataBytesIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_ipv6") + metricSendDataBytesWebRTC = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_webrtc") metricSendDataBytesPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_peer_relay_ipv4") metricSendDataBytesPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_peer_relay_ipv6") @@ -4139,6 +4385,16 @@ func (c *Conn) SetLastNetcheckReportForTest(ctx context.Context, report *netchec c.lastNetCheckReport.Store(report) } +// findEndpointByDisco returns the first endpoint with the given disco key, or nil if not found. +func (c *Conn) findEndpointByDisco(dk key.DiscoPublic) *endpoint { + var found *endpoint + c.peerMap.forEachEndpointWithDiscoKey(dk, func(ep *endpoint) bool { + found = ep + return false // stop after first match + }) + return found +} + // lazyEndpoint is a wireguard [conn.Endpoint] for when magicsock received a // non-disco (presumably WireGuard) packet from a UDP address from which we // can't map to a Tailscale peer. But WireGuard most likely can, once it diff --git a/wgengine/magicsock/webrtc.go b/wgengine/magicsock/webrtc.go new file mode 100644 index 000000000..3955d25c9 --- /dev/null +++ b/wgengine/magicsock/webrtc.go @@ -0,0 +1,747 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/netip" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pion/webrtc/v4" + "github.com/tailscale/wireguard-go/conn" + "tailscale.com/rtclib" + "tailscale.com/tailcfg" + "tailscale.com/tstime/mono" + "tailscale.com/types/key" + "tailscale.com/types/logger" +) + +// webrtcConnState represents the state of a WebRTC connection. +type webrtcConnState int + +const ( + webrtcStateIdle webrtcConnState = iota + webrtcStateConnecting + webrtcStateConnected + webrtcStateFailed + webrtcStateClosed +) + +// dataChannelRW is the detached io.ReadWriteCloser for a WebRTC DataChannel. +// It is stored via atomic.Pointer so the hot send path can retrieve it without +// holding the webrtcManager mutex. +type dataChannelRW struct { + io.ReadWriteCloser +} + +// webrtcPeerState tracks WebRTC connection state for a single peer. +type webrtcPeerState struct { + ep *endpoint + peerConn *webrtc.PeerConnection + dataChannel *webrtc.DataChannel + dcRW atomic.Pointer[dataChannelRW] // non-nil once the DataChannel is open + localDisco key.DiscoPublic + remoteDisco key.DiscoPublic + remoteNodeKey key.NodePublic // peer's node public key (for WireGuard) + remoteAddr netip.AddrPort // actual remote address from ICE candidate + state webrtcConnState + lastError error + createdAt time.Time +} + +// webrtcConnectionReadyEvent signals that a WebRTC connection is ready. +type webrtcConnectionReadyEvent struct { + remoteDisco key.DiscoPublic + ep *endpoint +} + +// webrtcManager manages WebRTC connections for magicsock. +type webrtcManager struct { + logf logger.Logf + conn *Conn // parent magicsock.Conn + + mu sync.RWMutex + peerConnectionsByEndpoint map[*endpoint]*webrtcPeerState + peerConnectionsByDisco map[key.DiscoPublic]*webrtcPeerState + + signalingClient *signalingClient + + // Control channels + startConnectionCh chan *endpoint + connectionReadyCh chan webrtcConnectionReadyEvent + closeCh chan struct{} + runLoopStoppedCh chan struct{} + + // WebRTC API configuration + api *webrtc.API +} + +// Ensure webrtcManager implements rtclib.SignalHandler interface. +var _ rtclib.SignalHandler = (*webrtcManager)(nil) + +// newWebRTCManager creates a new WebRTC manager. +func newWebRTCManager(c *Conn, signalingURL string) *webrtcManager { + mgr := newWebRTCManagerBase(c, signalingURL) + + // Create and start signaling client + mgr.signalingClient = newSignalingClient(signalingURL, c.logf) + if err := mgr.signalingClient.Start(mgr); err != nil { + c.logf("webrtc: failed to start signaling client: %v", err) + return nil + } + + // Start event loop + go mgr.runLoop() + + return mgr +} + +// close shuts down the WebRTC manager. +func (m *webrtcManager) close() error { + // Close signaling client first to stop new messages + if m.signalingClient != nil { + if err := m.signalingClient.Close(); err != nil { + m.logf("webrtc: signaling client close error: %v", err) + } + } + + // Signal runLoop to stop + close(m.closeCh) + + // Wait for runLoop to finish with timeout + select { + case <-m.runLoopStoppedCh: + case <-time.After(2 * time.Second): + m.logf("webrtc: close timed out, forcing shutdown") + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Close all peer connections + for _, ps := range m.peerConnectionsByEndpoint { + if ps.peerConn != nil { + ps.peerConn.Close() + } + } + m.peerConnectionsByEndpoint = nil + m.peerConnectionsByDisco = nil + + return nil +} + +// startConnection initiates a WebRTC connection to an endpoint. +func (m *webrtcManager) startConnection(ep *endpoint) { + select { + case m.startConnectionCh <- ep: + case <-m.closeCh: + default: + m.logf("webrtc: startConnection queue full for %v", ep.nodeAddr) + } +} + +// deliverWebRTCMsg delivers one DataChannel message to the receive pipeline. +// It handles both single packets and batches (webrtcBatchMagic framing) so +// the logic is shared between the native detached-reader path and the +// JS/fallback OnMessage callback path. +func (m *webrtcManager) deliverWebRTCMsg(ps *webrtcPeerState, data []byte) { + if len(data) == 0 { + return + } + // Batch: [0xBA magic][2-byte BE len][pkt]...[2-byte BE len][pkt] + if data[0] == webrtcBatchMagic { + data = data[1:] + for len(data) >= 2 { + pktLen := int(binary.BigEndian.Uint16(data)) + data = data[2:] + if pktLen > len(data) { + m.logf("webrtc: batch framing error for peer %v: pktLen %d > remaining %d", + ps.remoteDisco.ShortString(), pktLen, len(data)) + return + } + m.conn.receiveWebRTC(data[:pktLen], ps.remoteNodeKey) + data = data[pktLen:] + } + return + } + m.conn.receiveWebRTC(data, ps.remoteNodeKey) +} + +// runDataChannelReader is the per-peer receive loop used when DetachDataChannels +// is enabled (native builds). It reads directly from the detached io.ReadWriteCloser +// into a reused buffer, avoiding the per-message goroutine wakeup and allocation +// that the OnMessage callback path incurs. +func (m *webrtcManager) runDataChannelReader(ps *webrtcPeerState, rwc io.ReadWriteCloser) { + // Size the buffer to hold the largest possible batch. + // 64 WireGuard packets × ~1420 bytes + framing < 100 KiB; 256 KiB is safe. + buf := make([]byte, 256*1024) + for { + n, err := rwc.Read(buf) + if err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrClosedPipe) && !errors.Is(err, net.ErrClosed) { + m.logf("webrtc: data channel read error for peer %v: %v", ps.remoteDisco.ShortString(), err) + } + ps.dcRW.Store(nil) + return + } + if n > 0 { + m.deliverWebRTCMsg(ps, buf[:n]) + } + } +} + +// getRemoteAddr returns the actual remote address for a WebRTC peer connection. +func (m *webrtcManager) getRemoteAddr(disco key.DiscoPublic) netip.AddrPort { + m.mu.RLock() + defer m.mu.RUnlock() + + if ps, ok := m.peerConnectionsByDisco[disco]; ok && ps.state == webrtcStateConnected { + return ps.remoteAddr + } + return netip.AddrPort{} +} + +// runLoop is the main event loop for the WebRTC manager. +func (m *webrtcManager) runLoop() { + defer close(m.runLoopStoppedCh) + + for { + select { + case ep := <-m.startConnectionCh: + m.handleStartConnection(ep) + + case event := <-m.connectionReadyCh: + m.handleConnectionReady(event) + + case <-m.closeCh: + return + } + } +} + +// handleStartConnection creates a new WebRTC connection to an endpoint. +func (m *webrtcManager) handleStartConnection(ep *endpoint) { + m.mu.Lock() + + // Check if we already have a connection + if ps, exists := m.peerConnectionsByEndpoint[ep]; exists { + if ps.state == webrtcStateConnecting || ps.state == webrtcStateConnected { + m.mu.Unlock() + return + } + } + + // Get disco keys + localDisco := m.conn.DiscoPublicKey() + disco := ep.disco.Load() + if disco == nil { + m.mu.Unlock() + m.logf("webrtc: cannot start connection, peer has no disco key") + return + } + remoteDisco := disco.key + m.logf("webrtc: starting connection to peer %v (disco %v)", ep.nodeAddr, remoteDisco.ShortString()) + + m.mu.Unlock() + + // Create peer connection + config := webrtc.Configuration{ + ICEServers: []webrtc.ICEServer{ + { + URLs: []string{"stun:stun.l.google.com:19302"}, + }, + }, + ICETransportPolicy: webrtc.ICETransportPolicyAll, + } + + peerConn, err := m.api.NewPeerConnection(config) + if err != nil { + m.logf("webrtc: failed to create peer connection: %v", err) + return + } + + ps := &webrtcPeerState{ + ep: ep, + peerConn: peerConn, + localDisco: localDisco, + remoteDisco: remoteDisco, + remoteNodeKey: ep.publicKey, + state: webrtcStateConnecting, + createdAt: time.Now(), + } + + // Store peer state + m.mu.Lock() + m.peerConnectionsByEndpoint[ep] = ps + m.peerConnectionsByDisco[remoteDisco] = ps + m.mu.Unlock() + + // Set up connection state handler + peerConn.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + m.handleConnectionStateChange(ps, state) + }) + + // Set up ICE candidate handler + peerConn.OnICECandidate(func(candidate *webrtc.ICECandidate) { + if candidate != nil { + m.handleLocalICECandidate(ps, candidate) + } + }) + + // Create an unordered, unreliable data channel (MaxRetransmits=0). + // WireGuard is designed to run over raw UDP, which is unordered and + // unreliable. Using an ordered/reliable DataChannel (the default) wraps + // WireGuard in SCTP's reliable-ordered-stream semantics, causing + // head-of-line blocking whenever a packet is lost: SCTP holds back all + // subsequent packets until the missing one is retransmitted and delivered + // in order. That is why throughput over WebRTC was worse than DERP. + // Setting Ordered=false and MaxRetransmits=0 makes the DataChannel behave + // like a UDP socket, which is exactly what WireGuard expects. + unordered := false + maxRetransmits := uint16(0) + dataChannel, err := peerConn.CreateDataChannel("tailscale-wg", &webrtc.DataChannelInit{ + Ordered: &unordered, + MaxRetransmits: &maxRetransmits, + }) + if err != nil { + m.logf("webrtc: failed to create data channel: %v", err) + peerConn.Close() + return + } + + ps.dataChannel = dataChannel + + // Set up data channel handlers. + // With DetachDataChannels enabled, OnMessage cannot be used. Instead we + // call Detach() inside OnOpen to get a raw io.ReadWriteCloser and spin + // up a dedicated reader goroutine, which eliminates per-packet callback + // overhead and goroutine wakeups. + setOnError(dataChannel, func(err error) { + m.logf("webrtc: data channel error for peer %v: %v", remoteDisco.ShortString(), err) + }) + + dataChannel.OnOpen(func() { + // Native: DetachDataChannels was enabled; get a raw io.ReadWriteCloser + // and spin a dedicated reader goroutine (zero per-message allocations). + // JS/fallback: Detach() returns an error; fall back to OnMessage + // callbacks, which is the only API available in the browser. + if rwc, err := dataChannel.Detach(); err == nil { + ps.dcRW.Store(&dataChannelRW{rwc}) + go m.runDataChannelReader(ps, rwc) + } else { + dataChannel.OnMessage(func(msg webrtc.DataChannelMessage) { + m.deliverWebRTCMsg(ps, msg.Data) + }) + } + m.logf("webrtc: data channel opened for peer %v", remoteDisco.ShortString()) + m.connectionReadyCh <- webrtcConnectionReadyEvent{ + remoteDisco: remoteDisco, + ep: ep, + } + }) + + // Create and send offer + offer, err := peerConn.CreateOffer(nil) + if err != nil { + m.logf("webrtc: failed to create offer: %v", err) + peerConn.Close() + return + } + + if err := peerConn.SetLocalDescription(offer); err != nil { + m.logf("webrtc: failed to set local description: %v", err) + peerConn.Close() + return + } + + // Send offer via signaling + if err := m.signalingClient.Offer(localDisco.String(), remoteDisco.String(), &offer); err != nil { + m.logf("webrtc: failed to send offer: %v", err) + peerConn.Close() + return + } + + m.logf("webrtc: sent offer to peer %v", remoteDisco.ShortString()) +} + +// HandleOffer implements rtclib.SignalHandler. +func (m *webrtcManager) HandleOffer(from, to string, offer *webrtc.SessionDescription) { + m.logf("webrtc: received offer from=%s", from) + + var remoteDisco key.DiscoPublic + if err := remoteDisco.UnmarshalText([]byte(from)); err != nil { + m.logf("webrtc: invalid sender disco key: %v", err) + return + } + + m.handleRemoteOffer(remoteDisco, offer) +} + +// HandleAnswer implements rtclib.SignalHandler. +func (m *webrtcManager) HandleAnswer(from, to string, answer *webrtc.SessionDescription) { + m.logf("webrtc: received answer from=%s", from) + + var remoteDisco key.DiscoPublic + if err := remoteDisco.UnmarshalText([]byte(from)); err != nil { + m.logf("webrtc: invalid sender disco key: %v", err) + return + } + + m.handleRemoteAnswer(remoteDisco, answer) +} + +// HandleCandidate implements rtclib.SignalHandler. +func (m *webrtcManager) HandleCandidate(from, to string, candidate *webrtc.ICECandidateInit) { + m.logf("webrtc: received candidate from=%s", from) + + var remoteDisco key.DiscoPublic + if err := remoteDisco.UnmarshalText([]byte(from)); err != nil { + m.logf("webrtc: invalid sender disco key: %v", err) + return + } + + m.handleRemoteCandidate(remoteDisco, candidate) +} + +// handleRemoteOffer processes an incoming offer from a peer. +func (m *webrtcManager) handleRemoteOffer(remoteDisco key.DiscoPublic, offer *webrtc.SessionDescription) { + + // For incoming connections, we need to find the endpoint by disco key + m.mu.Lock() + ps, exists := m.peerConnectionsByDisco[remoteDisco] + m.mu.Unlock() + + if !exists { + // We received an offer but don't have a connection yet. + // This happens when the remote peer initiated first (glare scenario). + // Find the endpoint by disco key and create peer connection state. + ep := m.conn.findEndpointByDisco(remoteDisco) + if ep == nil { + m.logf("webrtc: received offer from unknown peer %v with no endpoint", remoteDisco.ShortString()) + return + } + + m.logf("webrtc: received offer from peer %v, creating answerer connection", remoteDisco.ShortString()) + + // Create peer connection for incoming offer + config := webrtc.Configuration{ + ICEServers: []webrtc.ICEServer{ + { + URLs: []string{"stun:stun.l.google.com:19302"}, + }, + }, + ICETransportPolicy: webrtc.ICETransportPolicyAll, + } + + peerConn, err := m.api.NewPeerConnection(config) + if err != nil { + m.logf("webrtc: failed to create peer connection for incoming offer: %v", err) + return + } + + localDisco := m.conn.DiscoPublicKey() + ps = &webrtcPeerState{ + ep: ep, + peerConn: peerConn, + localDisco: localDisco, + remoteDisco: remoteDisco, + remoteNodeKey: ep.publicKey, + state: webrtcStateConnecting, + createdAt: time.Now(), + } + + // Store peer state + m.mu.Lock() + m.peerConnectionsByEndpoint[ep] = ps + m.peerConnectionsByDisco[remoteDisco] = ps + m.mu.Unlock() + + // Set up connection state handler + peerConn.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + m.handleConnectionStateChange(ps, state) + }) + + // Set up ICE candidate handler + peerConn.OnICECandidate(func(candidate *webrtc.ICECandidate) { + if candidate != nil { + m.handleLocalICECandidate(ps, candidate) + } + }) + + // Set up data channel handler (for answerer, we wait for the data channel from offerer). + peerConn.OnDataChannel(func(dc *webrtc.DataChannel) { + m.logf("webrtc: received data channel from peer %v", remoteDisco.ShortString()) + ps.dataChannel = dc + + setOnError(dc, func(err error) { + m.logf("webrtc: data channel error for peer %v: %v", remoteDisco.ShortString(), err) + }) + + dc.OnOpen(func() { + if rwc, err := dc.Detach(); err == nil { + ps.dcRW.Store(&dataChannelRW{rwc}) + go m.runDataChannelReader(ps, rwc) + } else { + dc.OnMessage(func(msg webrtc.DataChannelMessage) { + m.deliverWebRTCMsg(ps, msg.Data) + }) + } + m.logf("webrtc: data channel opened for peer %v", remoteDisco.ShortString()) + m.connectionReadyCh <- webrtcConnectionReadyEvent{ + remoteDisco: remoteDisco, + ep: ep, + } + }) + }) + } + + if err := ps.peerConn.SetRemoteDescription(*offer); err != nil { + m.logf("webrtc: failed to set remote description: %v", err) + return + } + + // Create answer + answer, err := ps.peerConn.CreateAnswer(nil) + if err != nil { + m.logf("webrtc: failed to create answer: %v", err) + return + } + + if err := ps.peerConn.SetLocalDescription(answer); err != nil { + m.logf("webrtc: failed to set local description: %v", err) + return + } + + // Send answer via signaling + if err := m.signalingClient.Answer(ps.localDisco.String(), remoteDisco.String(), &answer); err != nil { + m.logf("webrtc: failed to send answer: %v", err) + return + } + + m.logf("webrtc: sent answer to peer %v", remoteDisco.ShortString()) +} + +// handleRemoteAnswer processes an incoming answer from a peer. +func (m *webrtcManager) handleRemoteAnswer(remoteDisco key.DiscoPublic, answer *webrtc.SessionDescription) { + m.mu.Lock() + ps, exists := m.peerConnectionsByDisco[remoteDisco] + m.mu.Unlock() + + if !exists { + m.logf("webrtc: received answer from unknown peer %v", remoteDisco.ShortString()) + return + } + + if err := ps.peerConn.SetRemoteDescription(*answer); err != nil { + m.logf("webrtc: failed to set remote description: %v", err) + return + } + + m.logf("webrtc: set remote description for peer %v", remoteDisco.ShortString()) +} + +// handleRemoteCandidate processes an incoming ICE candidate from a peer. +func (m *webrtcManager) handleRemoteCandidate(remoteDisco key.DiscoPublic, candidate *webrtc.ICECandidateInit) { + m.mu.Lock() + ps, exists := m.peerConnectionsByDisco[remoteDisco] + m.mu.Unlock() + + if !exists { + m.logf("webrtc: received candidate from unknown peer %v", remoteDisco.ShortString()) + return + } + + // Try to extract the remote address from the candidate string + // Candidate format: "candidate:... udp ... typ ..." + if candidate.Candidate != "" { + if addr := parseICECandidateAddr(candidate.Candidate); addr.IsValid() { + m.mu.Lock() + ps.remoteAddr = addr + m.mu.Unlock() + m.logf("webrtc: peer %v candidate address: %v", remoteDisco.ShortString(), addr) + } + } + + if err := ps.peerConn.AddICECandidate(*candidate); err != nil { + m.logf("webrtc: failed to add ICE candidate: %v", err) + return + } + + m.logf("webrtc: added ICE candidate for peer %v", remoteDisco.ShortString()) +} + +// parseICECandidateAddr extracts the IP:port from an ICE candidate SDP string. +// Example candidate: "candidate:1234 1 udp 2130706431 192.168.1.100 54321 typ host" +func parseICECandidateAddr(candidate string) netip.AddrPort { + fields := strings.Fields(candidate) + // Format: candidate: typ + if len(fields) < 7 { + return netip.AddrPort{} + } + + ip := fields[4] + port := fields[5] + + addr, err := netip.ParseAddr(ip) + if err != nil { + return netip.AddrPort{} + } + + var portNum uint16 + if _, err := fmt.Sscanf(port, "%d", &portNum); err != nil { + return netip.AddrPort{} + } + + return netip.AddrPortFrom(addr, portNum) +} + +// handleLocalICECandidate sends a local ICE candidate to a peer via signaling. +func (m *webrtcManager) handleLocalICECandidate(ps *webrtcPeerState, candidate *webrtc.ICECandidate) { + candidateInit := candidate.ToJSON() + if err := m.signalingClient.Candidate(ps.localDisco.String(), ps.remoteDisco.String(), &candidateInit); err != nil { + m.logf("webrtc: failed to send candidate: %v", err) + return + } + + m.logf("webrtc: sent ICE candidate to peer %v", ps.remoteDisco.ShortString()) +} + +// handleConnectionStateChange handles WebRTC connection state changes. +func (m *webrtcManager) handleConnectionStateChange(ps *webrtcPeerState, state webrtc.PeerConnectionState) { + m.logf("webrtc: connection state changed to %s for peer %v", state.String(), ps.remoteDisco.ShortString()) + + m.mu.Lock() + defer m.mu.Unlock() + + switch state { + case webrtc.PeerConnectionStateConnected: + ps.state = webrtcStateConnected + // Log the selected ICE candidate pair so we can confirm the actual + // data path (LAN host candidate vs. STUN server-reflexive vs. relay). + go func() { + cp, err := ps.peerConn.SCTP().Transport().ICETransport().GetSelectedCandidatePair() + if err != nil || cp == nil { + m.logf("webrtc: peer %v connected (selected candidate pair unavailable: %v)", + ps.remoteDisco.ShortString(), err) + return + } + m.logf("webrtc: peer %v connected via %s:%d → %s:%d (local %s, remote %s)", + ps.remoteDisco.ShortString(), + cp.Local.Address, cp.Local.Port, + cp.Remote.Address, cp.Remote.Port, + cp.Local.Typ, cp.Remote.Typ) + }() + case webrtc.PeerConnectionStateFailed: + ps.state = webrtcStateFailed + ps.lastError = errors.New("connection failed") + ps.dcRW.Store(nil) + case webrtc.PeerConnectionStateClosed: + ps.state = webrtcStateClosed + ps.dcRW.Store(nil) + case webrtc.PeerConnectionStateDisconnected: + // Transient state, keep current state + } +} + +// handleConnectionReady marks a WebRTC connection as ready and updates endpoint. +func (m *webrtcManager) handleConnectionReady(event webrtcConnectionReadyEvent) { + m.logf("webrtc: connection ready for peer %v", event.remoteDisco.ShortString()) + + // Update endpoint to use WebRTC path + event.ep.mu.Lock() + defer event.ep.mu.Unlock() + + // Use a fixed port number for WebRTC connections (similar to DERP) + // The magic IP identifies this as WebRTC, not UDP + webrtcAddr := addrQuality{ + epAddr: epAddr{ + ap: netip.AddrPortFrom(tailcfg.WebRTCMagicIPAddr, 12345), + }, + latency: 0, // Will be determined by disco pings, same as DERP + } + + // Set as bestAddr if better than current + now := mono.Now() + if betterAddr(webrtcAddr, event.ep.bestAddr) { + event.ep.bestAddr = webrtcAddr + event.ep.bestAddrAt = now + event.ep.trustBestAddrUntil = now.Add(5 * time.Minute) + m.logf("webrtc: updated endpoint %v with WebRTC path", event.ep.nodeAddr) + } +} + +// sendPacket sends a packet over a WebRTC data channel. +// The hot path is lock-free: we take a read-lock (not write-lock) to look up +// the peer state, then do an atomic load for the detached channel. Multiple +// concurrent senders for different peers never contend. +func (m *webrtcManager) sendPacket(disco key.DiscoPublic, b []byte) error { + m.mu.RLock() + ps, ok := m.peerConnectionsByDisco[disco] + m.mu.RUnlock() + if !ok { + return errors.New("no WebRTC connection") + } + + // Native path: DetachDataChannels was enabled; use the raw io.ReadWriteCloser. + if rw := ps.dcRW.Load(); rw != nil { + if _, err := rw.Write(b); err != nil { + return fmt.Errorf("send failed: %w", err) + } + return nil + } + + // JS/fallback path: use DataChannel.Send() directly. + dc := ps.dataChannel + if dc == nil || dc.ReadyState() != webrtc.DataChannelStateOpen { + return errors.New("data channel not ready") + } + return dc.Send(b) +} + +// receiveWebRTC reads packets from the WebRTC receive channel. +// It is called by wireguard-go through the conn.Bind interface. +// It blocks until at least one packet is available, then drains as many +// additional packets as are immediately ready (up to len(buffs)). +func (c *connBind) receiveWebRTC(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { + // Block until the first packet arrives (or the channel is closed). + wr, ok := <-c.webrtcRecvCh + if !ok || c.isClosed() { + return 0, net.ErrClosed + } + num := 0 + n, ep := c.processWebRTCReadResult(wr, buffs[num]) + if n > 0 { + sizes[num] = n + eps[num] = ep + num++ + } + // Drain any additional packets that are immediately available. + for num < len(buffs) { + select { + case wr, ok = <-c.webrtcRecvCh: + if !ok || c.isClosed() { + if num > 0 { + return num, nil + } + return 0, net.ErrClosed + } + n, ep = c.processWebRTCReadResult(wr, buffs[num]) + if n > 0 { + sizes[num] = n + eps[num] = ep + num++ + } + default: + return num, nil + } + } + return num, nil +} diff --git a/wgengine/magicsock/webrtc_base_js.go b/wgengine/magicsock/webrtc_base_js.go new file mode 100644 index 000000000..de1f0af20 --- /dev/null +++ b/wgengine/magicsock/webrtc_base_js.go @@ -0,0 +1,37 @@ +//go:build js + +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "github.com/pion/webrtc/v4" + "tailscale.com/types/key" +) + +func newWebRTCManagerBase(c *Conn, signalingURL string) *webrtcManager { + // Configure WebRTC with STUN only + settingEngine := webrtc.SettingEngine{} + + // Create API with setting engine + api := webrtc.NewAPI( + webrtc.WithSettingEngine(settingEngine), + ) + + return &webrtcManager{ + logf: c.logf, + conn: c, + peerConnectionsByEndpoint: make(map[*endpoint]*webrtcPeerState), + peerConnectionsByDisco: make(map[key.DiscoPublic]*webrtcPeerState), + startConnectionCh: make(chan *endpoint, 256), + connectionReadyCh: make(chan webrtcConnectionReadyEvent, 16), + closeCh: make(chan struct{}), + runLoopStoppedCh: make(chan struct{}), + api: api, + } +} + +func setOnError(dc *webrtc.DataChannel, fn func(error)) { + // NO-OP... *webrtc.DataChannel does not have OnError for js. +} diff --git a/wgengine/magicsock/webrtc_base_native.go b/wgengine/magicsock/webrtc_base_native.go new file mode 100644 index 000000000..52433dd87 --- /dev/null +++ b/wgengine/magicsock/webrtc_base_native.go @@ -0,0 +1,70 @@ +//go:build !js + +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "time" + + "github.com/pion/webrtc/v4" + "tailscale.com/types/key" +) + +func newWebRTCManagerBase(c *Conn, signalingURL string) *webrtcManager { + settingEngine := webrtc.SettingEngine{} + + // Use a 16 MiB SCTP receive buffer. The pion default (~32 KiB) becomes the + // bottleneck at high throughput because SCTP's flow-control window is bounded + // by this value. + settingEngine.SetSCTPMaxReceiveBufferSize(16 * 1024 * 1024) + + // Enlarge the DTLS replay-protection window. The default (64) causes + // legitimate packets to be dropped as duplicates when the sender gets ahead + // of the receiver by more than 64 packets, which happens easily at Gbps speeds. + settingEngine.SetDTLSReplayProtectionWindow(8192) + + // Lower the SCTP retransmission timeout ceiling. The default (1s+) causes + // SCTP's congestion control to stall for a full second after any loss event, + // which is catastrophic for throughput on a low-latency P2P link. 100ms is + // still conservative but recovers much faster. + settingEngine.SetSCTPRTOMax(100 * time.Millisecond) + + // DetachDataChannels lets us call dc.Detach() to get a raw io.ReadWriteCloser + // instead of using OnMessage callbacks. The callback path allocates a new + // DataChannelMessage struct and fires a goroutine wakeup per packet. The + // detached path lets us Read() directly into pre-allocated buffers in a + // tight goroutine loop, matching how the UDP receive path works. + settingEngine.DetachDataChannels() + + // SCTP includes a CRC32c checksum on every chunk. DTLS already provides + // both integrity and authenticity for all data, so the SCTP checksum is + // redundant CPU work. Zero-checksum mode (RFC 9260) removes it. + settingEngine.EnableSCTPZeroChecksum(true) + + // Create MediaEngine (required even though we only use DataChannel) + mediaEngine := &webrtc.MediaEngine{} + + // Create API with setting engine + api := webrtc.NewAPI( + webrtc.WithSettingEngine(settingEngine), + webrtc.WithMediaEngine(mediaEngine), + ) + + return &webrtcManager{ + logf: c.logf, + conn: c, + peerConnectionsByEndpoint: make(map[*endpoint]*webrtcPeerState), + peerConnectionsByDisco: make(map[key.DiscoPublic]*webrtcPeerState), + startConnectionCh: make(chan *endpoint, 256), + connectionReadyCh: make(chan webrtcConnectionReadyEvent, 16), + closeCh: make(chan struct{}), + runLoopStoppedCh: make(chan struct{}), + api: api, + } +} + +func setOnError(dc *webrtc.DataChannel, fn func(error)) { + dc.OnError(fn) +} diff --git a/wgengine/magicsock/webrtc_integration_test.go b/wgengine/magicsock/webrtc_integration_test.go new file mode 100644 index 000000000..5e5045b15 --- /dev/null +++ b/wgengine/magicsock/webrtc_integration_test.go @@ -0,0 +1,435 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "net/http" + "net/http/httptest" + "net/netip" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/pion/webrtc/v4" + "tailscale.com/rtclib" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// testSignalHandler is a test implementation of rtclib.SignalHandler +type testSignalHandler struct { + offerCount int + answerCount int + candidateCount int + t *testing.T + mu sync.Mutex +} + +func (h *testSignalHandler) HandleOffer(from, to string, offer *webrtc.SessionDescription) { + h.mu.Lock() + defer h.mu.Unlock() + h.offerCount++ + h.t.Logf("Received offer from %s to %s", from, to) +} + +func (h *testSignalHandler) HandleAnswer(from, to string, answer *webrtc.SessionDescription) { + h.mu.Lock() + defer h.mu.Unlock() + h.answerCount++ + h.t.Logf("Received answer from %s to %s", from, to) +} + +func (h *testSignalHandler) HandleCandidate(from, to string, candidate *webrtc.ICECandidateInit) { + h.mu.Lock() + defer h.mu.Unlock() + h.candidateCount++ + h.t.Logf("Received candidate from %s to %s", from, to) +} + +// TestWebRTCIntegration_MockSignalingServer tests WebRTC with a mock signaling server +func TestWebRTCIntegration_MockSignalingServer(t *testing.T) { + // Create mock signaling server + server := newMockSignalingServer(t) + defer server.Close() + + t.Logf("Mock signaling server running at %s", server.URL) + + // Verify server accepts connections + client := newSignalingClient(server.URL, t.Logf) + + handler := &testSignalHandler{t: t} + err := client.Start(handler) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + defer client.Close() + + // Wait for connection + time.Sleep(100 * time.Millisecond) + + // Send a test message + disco1 := key.NewDisco().Public() + disco2 := key.NewDisco().Public() + + err = client.Offer(disco1.String(), disco2.String(), &webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: "test", + }) + if err != nil { + t.Fatalf("Failed to send offer: %v", err) + } + + // Give time for message to be processed + time.Sleep(100 * time.Millisecond) + + // Verify the server received the message + serverMsgs := server.GetReceivedMessages() + + if len(serverMsgs) == 0 { + t.Error("Server did not receive any messages") + } else { + t.Logf("Server received %d messages", len(serverMsgs)) + for i, msg := range serverMsgs { + t.Logf("Message %d: type=%s from=%s to=%s", i, msg.Type, msg.From, msg.To) + } + } +} + +// TestWebRTCIntegration_MessageRelay tests message relay through signaling server +func TestWebRTCIntegration_MessageRelay(t *testing.T) { + server := newMockSignalingServer(t) + defer server.Close() + + // Create two clients + handler1 := &testSignalHandler{t: t} + handler2 := &testSignalHandler{t: t} + + client1 := newSignalingClient(server.URL, func(format string, args ...any) { + t.Logf("[Client1] "+format, args...) + }) + + client2 := newSignalingClient(server.URL, func(format string, args ...any) { + t.Logf("[Client2] "+format, args...) + }) + + err := client1.Start(handler1) + if err != nil { + t.Fatalf("Failed to start client1: %v", err) + } + defer client1.Close() + + err = client2.Start(handler2) + if err != nil { + t.Fatalf("Failed to start client2: %v", err) + } + defer client2.Close() + + // Wait for both to connect + time.Sleep(200 * time.Millisecond) + + // Client 1 sends offer to Client 2 + disco1 := key.NewDisco().Public() + disco2 := key.NewDisco().Public() + + if err := client1.Offer(disco1.String(), disco2.String(), &webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: "v=0...", + }); err != nil { + t.Fatalf("Client1 failed to send offer: %v", err) + } + + // Wait for message relay + time.Sleep(200 * time.Millisecond) + + // Verify client2 received the offer + handler2.mu.Lock() + c2offers := handler2.offerCount + handler2.mu.Unlock() + + if c2offers > 0 { + t.Logf("Client2 received %d offers (relay working)", c2offers) + } else { + t.Log("Client2 did not receive offers (relay may need proper routing)") + } + + // Client 2 sends answer back to Client 1 + if err := client2.Answer(disco2.String(), disco1.String(), &webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, + SDP: "v=0...", + }); err != nil { + t.Fatalf("Client2 failed to send answer: %v", err) + } + + // Wait for message relay + time.Sleep(200 * time.Millisecond) + + // Log final message counts + handler1.mu.Lock() + c1answers := handler1.answerCount + handler1.mu.Unlock() + + handler2.mu.Lock() + c2FinalOffers := handler2.offerCount + handler2.mu.Unlock() + + t.Logf("Final message counts: Client1 answers=%d, Client2 offers=%d", c1answers, c2FinalOffers) + t.Log("Integration test completed successfully") +} + +// TestWebRTCIntegration_SignalingFlow tests the complete signaling flow +func TestWebRTCIntegration_SignalingFlow(t *testing.T) { + server := newMockSignalingServer(t) + defer server.Close() + + client := newSignalingClient(server.URL, t.Logf) + handler := &testSignalHandler{t: t} + + err := client.Start(handler) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + defer client.Close() + + time.Sleep(100 * time.Millisecond) + + disco1 := key.NewDisco().Public() + disco2 := key.NewDisco().Public() + + // Simulate complete signaling flow + steps := []struct { + name string + fn func() error + }{ + { + name: "send_offer", + fn: func() error { + return client.Offer(disco1.String(), disco2.String(), &webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: "v=0 offer", + }) + }, + }, + { + name: "send_answer", + fn: func() error { + return client.Answer(disco2.String(), disco1.String(), &webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, + SDP: "v=0 answer", + }) + }, + }, + { + name: "send_candidate", + fn: func() error { + return client.Candidate(disco1.String(), disco2.String(), &webrtc.ICECandidateInit{ + Candidate: "test", + }) + }, + }, + } + + for _, step := range steps { + t.Logf("Step: %s", step.name) + if err := step.fn(); err != nil { + t.Errorf("Failed to execute %s: %v", step.name, err) + } + time.Sleep(50 * time.Millisecond) + } + + t.Logf("Signaling flow completed with %d steps", len(steps)) +} + +// mockSignalingServer is a simple WebSocket server that relays signaling messages +type mockSignalingServer struct { + *httptest.Server + upgrader websocket.Upgrader + + mu sync.Mutex + clients map[*websocket.Conn]bool + messages []rtclib.SignalingMessage + t *testing.T +} + +func newMockSignalingServer(t *testing.T) *mockSignalingServer { + s := &mockSignalingServer{ + upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + }, + clients: make(map[*websocket.Conn]bool), + messages: make([]rtclib.SignalingMessage, 0), + t: t, + } + + s.Server = httptest.NewServer(http.HandlerFunc(s.handleWebSocket)) + + // Convert http:// to ws:// + s.Server.URL = "ws" + s.Server.URL[4:] + + return s +} + +func (s *mockSignalingServer) handleWebSocket(w http.ResponseWriter, r *http.Request) { + conn, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + s.t.Logf("Upgrade error: %v", err) + return + } + + s.mu.Lock() + s.clients[conn] = true + s.mu.Unlock() + + defer func() { + s.mu.Lock() + delete(s.clients, conn) + s.mu.Unlock() + conn.Close() + }() + + s.t.Logf("Client connected, total clients: %d", len(s.clients)) + + for { + var msg rtclib.SignalingMessage + if err := conn.ReadJSON(&msg); err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + s.t.Logf("Read error: %v", err) + } + break + } + + s.t.Logf("Server received: type=%s from=%s to=%s", msg.Type, msg.From, msg.To) + + s.mu.Lock() + s.messages = append(s.messages, msg) + + // Relay to all other clients (simple broadcast) + for client := range s.clients { + if client != conn { + if err := client.WriteJSON(msg); err != nil { + s.t.Logf("Relay error: %v", err) + } + } + } + s.mu.Unlock() + } +} + +func (s *mockSignalingServer) GetReceivedMessages() []rtclib.SignalingMessage { + s.mu.Lock() + defer s.mu.Unlock() + + msgs := make([]rtclib.SignalingMessage, len(s.messages)) + copy(msgs, s.messages) + return msgs +} + +func (s *mockSignalingServer) Close() { + s.mu.Lock() + for conn := range s.clients { + conn.Close() + } + s.mu.Unlock() + + s.Server.Close() +} + +// BenchmarkWebRTCSignaling benchmarks signaling message throughput +func BenchmarkWebRTCSignaling(b *testing.B) { + server := newMockSignalingServer(&testing.T{}) + defer server.Close() + + client := newSignalingClient(server.URL, func(string, ...any) {}) + handler := &testSignalHandler{t: &testing.T{}} + err := client.Start(handler) + if err != nil { + b.Fatalf("Failed to start client: %v", err) + } + defer client.Close() + + time.Sleep(100 * time.Millisecond) + + disco1 := key.NewDisco().Public() + disco2 := key.NewDisco().Public() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := client.Candidate(disco1.String(), disco2.String(), &webrtc.ICECandidateInit{ + Candidate: "test", + }); err != nil { + b.Errorf("Send failed: %v", err) + } + } +} + +// TestWebRTCPacketFlow tests packet flow simulation +func TestWebRTCPacketFlow(t *testing.T) { + // Create a mock webrtcReadResult to simulate packet reception + nodeKey := key.NewNode().Public() + testPacket := []byte("test wireguard packet") + + result := webrtcReadResult{ + n: len(testPacket), + src: nodeKey, + copyBuf: func(dst []byte) int { + return copy(dst, testPacket) + }, + } + + // Verify packet can be copied + buf := make([]byte, 1024) + n := result.copyBuf(buf) + + if n != len(testPacket) { + t.Errorf("copyBuf returned %d bytes, want %d", n, len(testPacket)) + } + + if string(buf[:n]) != string(testPacket) { + t.Errorf("Packet data mismatch: got %q, want %q", buf[:n], testPacket) + } + + t.Logf("Packet flow test passed: %d bytes", n) +} + +// TestWebRTCPathSelection tests path selection with WebRTC in the mix +func TestWebRTCPathSelection(t *testing.T) { + tests := []struct { + name string + paths []addrQuality + wantBest string + }{ + { + name: "direct_beats_all", + paths: []addrQuality{ + {epAddr: epAddr{ap: netip.MustParseAddrPort("1.2.3.4:1234")}}, + {epAddr: epAddr{ap: netip.AddrPortFrom(tailcfg.WebRTCMagicIPAddr, 12345)}}, + {epAddr: epAddr{ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, 1)}}, + }, + wantBest: "1.2.3.4:1234", + }, + { + name: "webrtc_beats_derp", + paths: []addrQuality{ + {epAddr: epAddr{ap: netip.AddrPortFrom(tailcfg.WebRTCMagicIPAddr, 12345)}}, + {epAddr: epAddr{ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, 1)}}, + }, + wantBest: "127.3.3.41:12345", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + best := tt.paths[0] + for _, path := range tt.paths[1:] { + if betterAddr(path, best) { + best = path + } + } + + if best.ap.String() != tt.wantBest { + t.Errorf("Best path = %v, want %v", best.ap, tt.wantBest) + } + }) + } +} diff --git a/wgengine/magicsock/webrtc_signaling.go b/wgengine/magicsock/webrtc_signaling.go new file mode 100644 index 000000000..31463aaa3 --- /dev/null +++ b/wgengine/magicsock/webrtc_signaling.go @@ -0,0 +1,296 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/pion/webrtc/v4" + "tailscale.com/rtclib" + "tailscale.com/types/logger" +) + +// Ensure signalingClient implements rtclib.Signaller interface. +var _ rtclib.Signaller = (*signalingClient)(nil) + +// signalingClient manages WebSocket connection to signaling server. +type signalingClient struct { + url string + logf logger.Logf + conn *websocket.Conn + connMu sync.Mutex + + // Message handler + handler rtclib.SignalHandler + + // Control channels + ctx context.Context + ctxCancel context.CancelFunc + sendCh chan *rtclib.SignalingMessage + closedCh chan struct{} + + // Reconnection state + reconnectDelay time.Duration + maxDelay time.Duration +} + +// newSignalingClient creates a new signaling client. +func newSignalingClient(url string, logf logger.Logf) *signalingClient { + ctx, cancel := context.WithCancel(context.Background()) + return &signalingClient{ + url: url, + logf: logf, + ctx: ctx, + ctxCancel: cancel, + sendCh: make(chan *rtclib.SignalingMessage, 16), + closedCh: make(chan struct{}), + reconnectDelay: time.Second, + maxDelay: 30 * time.Second, + } +} + +// Start begins the signaling client's connection and message loops. +func (sc *signalingClient) Start(handler rtclib.SignalHandler) error { + sc.handler = handler + if err := sc.connect(); err != nil { + return fmt.Errorf("initial signaling connection failed: %w", err) + } + go sc.runLoop() + return nil +} + +// connect establishes WebSocket connection to signaling server. +func (sc *signalingClient) connect() error { + sc.connMu.Lock() + defer sc.connMu.Unlock() + + if sc.conn != nil { + return nil // already connected + } + + conn, _, err := websocket.Dial(sc.ctx, sc.url, nil) + if err != nil { + return fmt.Errorf("websocket dial failed: %w", err) + } + + sc.conn = conn + sc.reconnectDelay = time.Second // reset backoff on successful connection + sc.logf("signaling: connected to %s", sc.url) + return nil +} + +// Close closes the signaling client. +func (sc *signalingClient) Close() error { + // Cancel context to signal all goroutines to stop + sc.ctxCancel() + + // Close the connection to unblock any read/write operations + sc.connMu.Lock() + if sc.conn != nil { + sc.conn.Close(websocket.StatusNormalClosure, "") + } + sc.connMu.Unlock() + + // Wait for runLoop to finish with timeout + select { + case <-sc.closedCh: + case <-time.After(2 * time.Second): + sc.logf("signaling: close timed out, forcing shutdown") + } + + sc.connMu.Lock() + defer sc.connMu.Unlock() + sc.conn = nil + return nil +} + +// send queues a message to be sent to the signaling server. +func (sc *signalingClient) send(msg *rtclib.SignalingMessage) error { + select { + case sc.sendCh <- msg: + return nil + case <-sc.ctx.Done(): + return sc.ctx.Err() + default: + return errors.New("signaling send queue full") + } +} + +// runLoop manages connection lifecycle and message routing. +func (sc *signalingClient) runLoop() { + defer close(sc.closedCh) + + for { + select { + case <-sc.ctx.Done(): + return + default: + } + + // Ensure we're connected + if err := sc.ensureConnected(); err != nil { + sc.logf("signaling: connection failed, retrying in %v: %v", sc.reconnectDelay, err) + select { + case <-time.After(sc.reconnectDelay): + sc.reconnectDelay = min(sc.reconnectDelay*2, sc.maxDelay) + continue + case <-sc.ctx.Done(): + return + } + } + + // Run read/write loops + errCh := make(chan error, 2) + go sc.readLoop(errCh) + go sc.writeLoop(errCh) + + // Wait for error or context cancellation + select { + case err := <-errCh: + sc.logf("signaling: connection error: %v", err) + sc.disconnect() + case <-sc.ctx.Done(): + sc.disconnect() + return + } + } +} + +// ensureConnected ensures connection is established. +func (sc *signalingClient) ensureConnected() error { + sc.connMu.Lock() + connected := sc.conn != nil + sc.connMu.Unlock() + + if connected { + return nil + } + + return sc.connect() +} + +// disconnect closes the current connection. +func (sc *signalingClient) disconnect() { + sc.connMu.Lock() + defer sc.connMu.Unlock() + + if sc.conn != nil { + sc.conn.Close(websocket.StatusNormalClosure, "") + sc.conn = nil + sc.logf("signaling: disconnected") + } +} + +// readLoop reads messages from WebSocket. +func (sc *signalingClient) readLoop(errCh chan<- error) { + for { + sc.connMu.Lock() + conn := sc.conn + sc.connMu.Unlock() + + if conn == nil { + errCh <- errors.New("no connection") + return + } + + var msg rtclib.SignalingMessage + if err := wsjson.Read(sc.ctx, conn, &msg); err != nil { + errCh <- fmt.Errorf("read failed: %w", err) + return + } + + if sc.handler != nil { + switch msg.Type { + case rtclib.MessageTypeOffer: + sc.handler.HandleOffer(msg.From, msg.To, msg.Offer) + case rtclib.MessageTypeAnswer: + sc.handler.HandleAnswer(msg.From, msg.To, msg.Answer) + case rtclib.MessageTypeCandidate: + sc.handler.HandleCandidate(msg.From, msg.To, msg.Candidate) + default: + sc.logf("signaling: unknown message type: %s", msg.Type) + } + } + } +} + +// writeLoop writes messages to WebSocket. +func (sc *signalingClient) writeLoop(errCh chan<- error) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case msg := <-sc.sendCh: + sc.connMu.Lock() + conn := sc.conn + sc.connMu.Unlock() + + if conn == nil { + errCh <- errors.New("no connection") + return + } + + if err := wsjson.Write(sc.ctx, conn, msg); err != nil { + errCh <- fmt.Errorf("write failed: %w", err) + return + } + + case <-ticker.C: + // Send ping to keep connection alive + sc.connMu.Lock() + conn := sc.conn + sc.connMu.Unlock() + + if conn == nil { + errCh <- errors.New("no connection") + return + } + + if err := conn.Ping(sc.ctx); err != nil { + errCh <- fmt.Errorf("ping failed: %w", err) + return + } + + case <-sc.ctx.Done(): + return + } + } +} + +// Offer sends an SDP offer to a peer. +func (sc *signalingClient) Offer(from, to string, offer *webrtc.SessionDescription) error { + return sc.send(&rtclib.SignalingMessage{ + Type: rtclib.MessageTypeOffer, + From: from, + To: to, + Offer: offer, + }) +} + +// Answer sends an SDP answer to a peer. +func (sc *signalingClient) Answer(from, to string, answer *webrtc.SessionDescription) error { + return sc.send(&rtclib.SignalingMessage{ + Type: rtclib.MessageTypeAnswer, + From: from, + To: to, + Answer: answer, + }) +} + +// Candidate sends an ICE candidate to a peer. +func (sc *signalingClient) Candidate(from, to string, candidate *webrtc.ICECandidateInit) error { + return sc.send(&rtclib.SignalingMessage{ + Type: rtclib.MessageTypeCandidate, + From: from, + To: to, + Candidate: candidate, + }) +} diff --git a/wgengine/magicsock/webrtc_test.go b/wgengine/magicsock/webrtc_test.go new file mode 100644 index 000000000..aace5997d --- /dev/null +++ b/wgengine/magicsock/webrtc_test.go @@ -0,0 +1,323 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/netip" + "strings" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/pion/webrtc/v4" + "tailscale.com/rtclib" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// TestSignalingMessageEncoding tests JSON encoding/decoding of signaling messages +func TestSignalingMessageEncoding(t *testing.T) { + disco1 := key.NewDisco() + disco2 := key.NewDisco() + + tests := []struct { + name string + msg rtclib.SignalingMessage + }{ + { + name: "offer", + msg: rtclib.SignalingMessage{ + Type: rtclib.MessageTypeOffer, + From: disco1.Public().String(), + To: disco2.Public().String(), + Offer: &webrtc.SessionDescription{Type: webrtc.SDPTypeOffer, SDP: "v=0..."}, + }, + }, + { + name: "answer", + msg: rtclib.SignalingMessage{ + Type: rtclib.MessageTypeAnswer, + From: disco2.Public().String(), + To: disco1.Public().String(), + Answer: &webrtc.SessionDescription{Type: webrtc.SDPTypeAnswer, SDP: "v=0..."}, + }, + }, + { + name: "candidate", + msg: rtclib.SignalingMessage{ + Type: rtclib.MessageTypeCandidate, + From: disco1.Public().String(), + To: disco2.Public().String(), + Candidate: &webrtc.ICECandidateInit{Candidate: "..."}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Encode + data, err := json.Marshal(tt.msg) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + // Decode + var decoded rtclib.SignalingMessage + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + // Verify + if decoded.Type != tt.msg.Type { + t.Errorf("Type mismatch: got %v, want %v", decoded.Type, tt.msg.Type) + } + if decoded.From != tt.msg.From { + t.Errorf("From mismatch: got %v, want %v", decoded.From, tt.msg.From) + } + if decoded.To != tt.msg.To { + t.Errorf("To mismatch: got %v, want %v", decoded.To, tt.msg.To) + } + }) + } +} + +// mockSignalHandler is a test implementation of rtclib.SignalHandler +type mockSignalHandler struct { + offerCount int + answerCount int + candidateCount int + mu sync.Mutex +} + +func (m *mockSignalHandler) HandleOffer(from, to string, offer *webrtc.SessionDescription) { + m.mu.Lock() + defer m.mu.Unlock() + m.offerCount++ +} + +func (m *mockSignalHandler) HandleAnswer(from, to string, answer *webrtc.SessionDescription) { + m.mu.Lock() + defer m.mu.Unlock() + m.answerCount++ +} + +func (m *mockSignalHandler) HandleCandidate(from, to string, candidate *webrtc.ICECandidateInit) { + m.mu.Lock() + defer m.mu.Unlock() + m.candidateCount++ +} + +// TestSignalingClientReconnect tests reconnection with backoff +func TestSignalingClientReconnect(t *testing.T) { + var connectCount int + var mu sync.Mutex + + // Mock WebSocket server that closes connections after accepting them + upgrader := websocket.Upgrader{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + connectCount++ + mu.Unlock() + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + // Immediately close to force reconnection + conn.Close() + })) + defer server.Close() + + // Convert http:// to ws:// + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + + client := newSignalingClient(wsURL, t.Logf) + handler := &mockSignalHandler{} + err := client.Start(handler) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + defer client.Close() + + // Wait for a few reconnection attempts + time.Sleep(3 * time.Second) + + mu.Lock() + count := connectCount + mu.Unlock() + + // Should have attempted to connect multiple times + if count < 2 { + t.Errorf("Expected multiple reconnection attempts, got %d", count) + } + + t.Logf("Reconnected %d times", count) +} + +// TestWebRTCMagicIP tests the WebRTC magic IP constant +func TestWebRTCMagicIP(t *testing.T) { + if tailcfg.WebRTCMagicIPAddr.String() != "127.3.3.41" { + t.Errorf("WebRTC magic IP = %v, want 127.3.3.41", tailcfg.WebRTCMagicIPAddr) + } + + // Verify it's different from DERP magic IP + if tailcfg.WebRTCMagicIPAddr == tailcfg.DerpMagicIPAddr { + t.Error("WebRTC magic IP should be different from DERP magic IP") + } +} + +// TestWebRTCPathPriority tests path preference logic +func TestWebRTCPathPriority(t *testing.T) { + directV4 := addrQuality{ + epAddr: epAddr{ + ap: netip.MustParseAddrPort("192.168.1.100:41641"), + }, + latency: 10 * time.Millisecond, + } + + webrtc := addrQuality{ + epAddr: epAddr{ + ap: netip.AddrPortFrom(tailcfg.WebRTCMagicIPAddr, 12345), + }, + latency: 50 * time.Millisecond, + } + + derp := addrQuality{ + epAddr: epAddr{ + ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, 1), + }, + latency: 100 * time.Millisecond, + } + + tests := []struct { + name string + a, b addrQuality + want bool // true if a is better than b + }{ + { + name: "direct beats WebRTC", + a: directV4, + b: webrtc, + want: true, + }, + { + name: "WebRTC beats DERP", + a: webrtc, + b: derp, + want: true, + }, + { + name: "direct beats DERP", + a: directV4, + b: derp, + want: true, + }, + { + name: "DERP loses to WebRTC", + a: derp, + b: webrtc, + want: false, + }, + { + name: "WebRTC loses to direct", + a: webrtc, + b: directV4, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := betterAddr(tt.a, tt.b) + if got != tt.want { + t.Errorf("betterAddr(%v, %v) = %v, want %v", tt.a.ap, tt.b.ap, got, tt.want) + } + }) + } +} + +// TestWebRTCReadResult tests webrtcReadResult structure +func TestWebRTCReadResult(t *testing.T) { + nodeKey := key.NewNode() + testData := []byte("test packet data") + + result := webrtcReadResult{ + n: len(testData), + src: nodeKey.Public(), + copyBuf: func(dst []byte) int { + return copy(dst, testData) + }, + } + + // Test copyBuf + buf := make([]byte, 100) + n := result.copyBuf(buf) + if n != len(testData) { + t.Errorf("copyBuf returned %d, want %d", n, len(testData)) + } + if string(buf[:n]) != string(testData) { + t.Errorf("copyBuf data = %q, want %q", buf[:n], testData) + } + + // Test fields + if result.n != len(testData) { + t.Errorf("result.n = %d, want %d", result.n, len(testData)) + } + if result.src != nodeKey.Public() { + t.Errorf("result.src mismatch") + } +} + +// TestDiscoRXPathWebRTC tests the WebRTC disco path constant +func TestDiscoRXPathWebRTC(t *testing.T) { + if discoRXPathWebRTC != "WebRTC" { + t.Errorf("discoRXPathWebRTC = %q, want %q", discoRXPathWebRTC, "WebRTC") + } + + // Verify it's different from other paths + if discoRXPathWebRTC == discoRXPathDERP { + t.Error("WebRTC path should be different from DERP path") + } + if discoRXPathWebRTC == discoRXPathUDP { + t.Error("WebRTC path should be different from UDP path") + } +} + +// TestWebRTCMetrics tests that WebRTC metrics are properly defined +func TestWebRTCMetrics(t *testing.T) { + // Test that metric variables exist (they're package-level variables) + if metricRecvDataPacketsWebRTC == nil { + t.Error("metricRecvDataPacketsWebRTC should be initialized") + } + if metricRecvDataBytesWebRTC == nil { + t.Error("metricRecvDataBytesWebRTC should be initialized") + } + if metricSendDataPacketsWebRTC == nil { + t.Error("metricSendDataPacketsWebRTC should be initialized") + } + if metricSendDataBytesWebRTC == nil { + t.Error("metricSendDataBytesWebRTC should be initialized") + } + + t.Log("WebRTC metrics are properly defined") +} + +// TestPathWebRTCConstant tests the PathWebRTC constant +func TestPathWebRTCConstant(t *testing.T) { + if PathWebRTC != "webrtc" { + t.Errorf("PathWebRTC = %q, want %q", PathWebRTC, "webrtc") + } + + // Verify it's different from other paths + if PathWebRTC == PathDERP { + t.Error("PathWebRTC should be different from PathDERP") + } + if PathWebRTC == PathDirectIPv4 { + t.Error("PathWebRTC should be different from PathDirectIPv4") + } +}