mirror of
https://github.com/tailscale/tailscale.git
synced 2026-05-05 12:16:44 +02:00
wgengine/magicsock: add webrtc path to magicsock (experimental)
This commit is contained in:
parent
7477a6ee47
commit
ec53be090f
@ -23,6 +23,7 @@ import (
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
||||
@ -196,7 +197,19 @@ func runStatus(ctx context.Context, args []string) error {
|
||||
if relay != "" && ps.CurAddr == "" && ps.PeerRelay == "" {
|
||||
f("relay %q", relay)
|
||||
} else if ps.CurAddr != "" {
|
||||
f("direct %s", ps.CurAddr)
|
||||
// Check if this is a WebRTC connection (address matches WebRTC magic IP)
|
||||
if strings.HasPrefix(ps.CurAddr, tailcfg.WebRTCMagicIP) {
|
||||
// Extract the actual remote address from CurAddr, which for a WebRTC path
|
||||
// is of the form "${WEBRTC_MAGIC_IP}:${DUMMY_PORT} (${REAL_IP_AND_PORT})"
|
||||
// e.g. "127.3.3.41:12345 (134.209.53.229:37792)".
|
||||
realRemoteAddr := "<UNKNOWN>"
|
||||
if idx := strings.Index(ps.CurAddr, " ("); idx > 0 {
|
||||
realRemoteAddr = ps.CurAddr[idx+2 : len(ps.CurAddr)-1]
|
||||
}
|
||||
f("webrtc %s", realRemoteAddr)
|
||||
} else {
|
||||
f("direct %s", ps.CurAddr)
|
||||
}
|
||||
} else if ps.PeerRelay != "" {
|
||||
f("peer-relay %s", ps.PeerRelay)
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
package disco
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -51,6 +52,14 @@ const (
|
||||
TypeCallMeMaybeVia = MessageType(0x07)
|
||||
TypeAllocateUDPRelayEndpointRequest = MessageType(0x08)
|
||||
TypeAllocateUDPRelayEndpointResponse = MessageType(0x09)
|
||||
// TypeWebRTCOffer, TypeWebRTCAnswer, and TypeWebRTCICECandidate carry
|
||||
// WebRTC signaling payloads between Tailscale peers via DERP, eliminating
|
||||
// the need for an external signaling server. Each message's payload is a
|
||||
// JSON-encoded WebRTC type (*webrtc.SessionDescription or
|
||||
// *webrtc.ICECandidateInit).
|
||||
TypeWebRTCOffer = MessageType(0x0A)
|
||||
TypeWebRTCAnswer = MessageType(0x0B)
|
||||
TypeWebRTCICECandidate = MessageType(0x0C)
|
||||
)
|
||||
|
||||
const v0 = byte(0)
|
||||
@ -103,6 +112,12 @@ func Parse(p []byte) (Message, error) {
|
||||
return parseAllocateUDPRelayEndpointRequest(ver, p)
|
||||
case TypeAllocateUDPRelayEndpointResponse:
|
||||
return parseAllocateUDPRelayEndpointResponse(ver, p)
|
||||
case TypeWebRTCOffer:
|
||||
return parseWebRTCOffer(ver, p)
|
||||
case TypeWebRTCAnswer:
|
||||
return parseWebRTCAnswer(ver, p)
|
||||
case TypeWebRTCICECandidate:
|
||||
return parseWebRTCICECandidate(ver, p)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown message type 0x%02x", byte(t))
|
||||
}
|
||||
@ -278,6 +293,48 @@ func parsePong(ver uint8, p []byte) (m *Pong, err error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// WebRTCOffer is sent only over DERP to deliver a WebRTC SDP offer to a peer.
|
||||
// Payload is a JSON-encoded *webrtc.SessionDescription.
|
||||
type WebRTCOffer struct{ Payload []byte }
|
||||
|
||||
func (m *WebRTCOffer) AppendMarshal(b []byte) []byte {
|
||||
ret, p := appendMsgHeader(b, TypeWebRTCOffer, v0, len(m.Payload))
|
||||
copy(p, m.Payload)
|
||||
return ret
|
||||
}
|
||||
|
||||
func parseWebRTCOffer(_ uint8, p []byte) (*WebRTCOffer, error) {
|
||||
return &WebRTCOffer{Payload: bytes.Clone(p)}, nil
|
||||
}
|
||||
|
||||
// WebRTCAnswer is sent only over DERP to deliver a WebRTC SDP answer to a peer.
|
||||
// Payload is a JSON-encoded *webrtc.SessionDescription.
|
||||
type WebRTCAnswer struct{ Payload []byte }
|
||||
|
||||
func (m *WebRTCAnswer) AppendMarshal(b []byte) []byte {
|
||||
ret, p := appendMsgHeader(b, TypeWebRTCAnswer, v0, len(m.Payload))
|
||||
copy(p, m.Payload)
|
||||
return ret
|
||||
}
|
||||
|
||||
func parseWebRTCAnswer(_ uint8, p []byte) (*WebRTCAnswer, error) {
|
||||
return &WebRTCAnswer{Payload: bytes.Clone(p)}, nil
|
||||
}
|
||||
|
||||
// WebRTCICECandidate is sent only over DERP to deliver a WebRTC ICE candidate
|
||||
// to a peer. Payload is a JSON-encoded *webrtc.ICECandidateInit.
|
||||
type WebRTCICECandidate struct{ Payload []byte }
|
||||
|
||||
func (m *WebRTCICECandidate) AppendMarshal(b []byte) []byte {
|
||||
ret, p := appendMsgHeader(b, TypeWebRTCICECandidate, v0, len(m.Payload))
|
||||
copy(p, m.Payload)
|
||||
return ret
|
||||
}
|
||||
|
||||
func parseWebRTCICECandidate(_ uint8, p []byte) (*WebRTCICECandidate, error) {
|
||||
return &WebRTCICECandidate{Payload: bytes.Clone(p)}, nil
|
||||
}
|
||||
|
||||
// MessageSummary returns a short summary of m for logging purposes.
|
||||
func MessageSummary(m Message) string {
|
||||
switch m := m.(type) {
|
||||
@ -299,6 +356,12 @@ func MessageSummary(m Message) string {
|
||||
return "allocate-udp-relay-endpoint-request"
|
||||
case *AllocateUDPRelayEndpointResponse:
|
||||
return "allocate-udp-relay-endpoint-response"
|
||||
case *WebRTCOffer:
|
||||
return "webrtc-offer"
|
||||
case *WebRTCAnswer:
|
||||
return "webrtc-answer"
|
||||
case *WebRTCICECandidate:
|
||||
return "webrtc-ice-candidate"
|
||||
default:
|
||||
return fmt.Sprintf("%#v", m)
|
||||
}
|
||||
|
||||
5
example-webrtc-server/go.mod
Normal file
5
example-webrtc-server/go.mod
Normal file
@ -0,0 +1,5 @@
|
||||
module github.com/adrianosela/tailscale/example-webrtc-server
|
||||
|
||||
go 1.26
|
||||
|
||||
require github.com/gorilla/websocket v1.5.3
|
||||
2
example-webrtc-server/go.sum
Normal file
2
example-webrtc-server/go.sum
Normal file
@ -0,0 +1,2 @@
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
466
example-webrtc-server/main.go
Normal file
466
example-webrtc-server/main.go
Normal file
@ -0,0 +1,466 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// example-webrtc-server is a WebRTC signaling server that supports both
|
||||
// WebSocket (for Tailscale) and HTTP REST (for standard WebRTC clients).
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// SignalingMessage represents a WebRTC signaling message.
// This format is compatible with both Tailscale and standard WebRTC clients.
type SignalingMessage struct {
	Type string `json:"type"` // "offer", "answer", "candidate"
	From string `json:"from"` // sender's disco public key (hex)
	To   string `json:"to"`   // recipient's disco public key (hex)

	// For SDP offer/answer (raw JSON for flexibility)
	Offer     json.RawMessage `json:"offer,omitempty"`
	Answer    json.RawMessage `json:"answer,omitempty"`
	Candidate json.RawMessage `json:"candidate,omitempty"`

	// Legacy fields for HTTP REST clients
	SDP string `json:"sdp,omitempty"` // Used by non-Tailscale clients

	// Timestamp is stamped by RouteMessage when the message is routed;
	// CleanupOldMessages uses it to expire stale queued messages.
	Timestamp time.Time `json:"timestamp"`
}
||||
|
||||
// Client represents a connected peer (WebSocket or HTTP polling)
type Client struct {
	ID       string          // peer ID; taken from the From field of the client's first message
	Conn     *websocket.Conn // nil for HTTP clients
	LastSeen time.Time       // updated on each received message; used by CleanupOldMessages to expire idle clients
}
||||
|
||||
// SignalingServer manages WebRTC signaling between peers
type SignalingServer struct {
	mu      sync.RWMutex       // guards clients, messages, and stats
	clients map[string]*Client // Active WebSocket clients

	// Message queue for HTTP polling clients
	messages map[string][]SignalingMessage // Key: "to" peer ID

	upgrader websocket.Upgrader

	// Statistics exposed via GetStats (and the /stats endpoint).
	stats struct {
		totalMessages  int // every message passed to RouteMessage
		wsConnections  int // currently-open WebSocket connections
		httpPolls      int // GetMessages calls made by HTTP polling clients
		activeOffers   int // offers routed without a matching answer yet
		completedPairs int // answers routed (offer/answer pairs completed)
	}
}
||||
|
||||
// NewSignalingServer creates a new signaling server
|
||||
func NewSignalingServer() *SignalingServer {
|
||||
return &SignalingServer{
|
||||
clients: make(map[string]*Client),
|
||||
messages: make(map[string][]SignalingMessage),
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // Allow all origins (configure as needed)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RouteMessage delivers a message to the destination peer.
//
// It stamps msg.Timestamp, updates the offer/answer statistics, then tries a
// live WebSocket connection for msg.To; if none exists (or the write fails)
// the message is queued for HTTP polling via GetMessages. clientIP is used
// only for logging. The error result is reserved for future use: every path
// currently returns nil.
//
// NOTE(review): the WebSocket write happens while s.mu is held, so one slow
// peer connection can stall all other routing — confirm this is acceptable.
func (s *SignalingServer) RouteMessage(msg SignalingMessage, clientIP string) error {
	msg.Timestamp = time.Now()

	s.mu.Lock()
	defer s.mu.Unlock()

	// Update stats
	s.stats.totalMessages++
	switch msg.Type {
	case "offer":
		s.stats.activeOffers++
		// Clear old messages from this sender when starting a new session
		s.clearOldMessages(msg.From, msg.To)
	case "answer":
		s.stats.completedPairs++
		if s.stats.activeOffers > 0 {
			s.stats.activeOffers--
		}
	}

	log.Printf("[%s] Routing %s from %s to %s", clientIP, msg.Type, msg.From, msg.To)

	// Try to deliver to WebSocket client first
	if client, ok := s.clients[msg.To]; ok && client.Conn != nil {
		// Send directly via WebSocket
		if err := client.Conn.WriteJSON(msg); err != nil {
			log.Printf("[%s] Failed to send to WebSocket client %s: %v", clientIP, msg.To, err)
			// Remove dead connection
			delete(s.clients, msg.To)
			// Fall through to queue message
		} else {
			log.Printf("[%s] Delivered %s to WebSocket client %s", clientIP, msg.Type, msg.To)
			return nil
		}
	}

	// Queue for HTTP polling client
	s.messages[msg.To] = append(s.messages[msg.To], msg)
	log.Printf("[%s] Queued %s for HTTP client %s", clientIP, msg.Type, msg.To)

	return nil
}
|
||||
|
||||
// clearOldMessages removes previous messages between two peers (used when new session starts)
|
||||
func (s *SignalingServer) clearOldMessages(from, to string) {
|
||||
if msgs, ok := s.messages[to]; ok {
|
||||
filtered := make([]SignalingMessage, 0)
|
||||
for _, msg := range msgs {
|
||||
if msg.From != from {
|
||||
filtered = append(filtered, msg)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
delete(s.messages, to)
|
||||
} else {
|
||||
s.messages[to] = filtered
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetMessages retrieves queued messages for an HTTP polling client.
//
// It drains and returns the ENTIRE queue for peerID (deleting the map entry),
// or returns nil when nothing is queued. clientIP is used only for logging.
// Callers that deliver fewer messages than returned are responsible for the
// remainder — the queue no longer holds them after this call.
func (s *SignalingServer) GetMessages(peerID, clientIP string) []SignalingMessage {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.stats.httpPolls++

	messages := s.messages[peerID]
	if len(messages) == 0 {
		return nil
	}

	// Return all messages and clear the queue
	delete(s.messages, peerID)

	log.Printf("[%s] Delivering %d queued message(s) to HTTP client %s", clientIP, len(messages), peerID)
	return messages
}
||||
|
||||
// CleanupOldMessages removes stale messages and dead connections
|
||||
func (s *SignalingServer) CleanupOldMessages(maxAge time.Duration) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
cutoff := time.Now().Add(-maxAge)
|
||||
cleaned := 0
|
||||
|
||||
// Clean old messages
|
||||
for peerID, messages := range s.messages {
|
||||
filtered := make([]SignalingMessage, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
if msg.Timestamp.After(cutoff) {
|
||||
filtered = append(filtered, msg)
|
||||
} else {
|
||||
cleaned++
|
||||
}
|
||||
}
|
||||
|
||||
if len(filtered) == 0 {
|
||||
delete(s.messages, peerID)
|
||||
} else {
|
||||
s.messages[peerID] = filtered
|
||||
}
|
||||
}
|
||||
|
||||
// Clean inactive clients
|
||||
for id, client := range s.clients {
|
||||
if time.Since(client.LastSeen) > maxAge {
|
||||
if client.Conn != nil {
|
||||
client.Conn.Close()
|
||||
}
|
||||
delete(s.clients, id)
|
||||
cleaned++
|
||||
}
|
||||
}
|
||||
|
||||
if cleaned > 0 {
|
||||
log.Printf("Cleaned up %d old messages/connections", cleaned)
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats returns current server statistics
|
||||
func (s *SignalingServer) GetStats() map[string]interface{} {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
queuedMessages := 0
|
||||
for _, msgs := range s.messages {
|
||||
queuedMessages += len(msgs)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_messages": s.stats.totalMessages,
|
||||
"ws_connections": s.stats.wsConnections,
|
||||
"http_polls": s.stats.httpPolls,
|
||||
"active_offers": s.stats.activeOffers,
|
||||
"completed_pairs": s.stats.completedPairs,
|
||||
"queued_messages": queuedMessages,
|
||||
"active_ws_clients": len(s.clients),
|
||||
"active_peer_ids": len(s.messages),
|
||||
}
|
||||
}
|
||||
|
||||
// WebSocket Handler (for Tailscale clients)
//
// handleWebSocket upgrades the request to a WebSocket and runs a read loop
// until the connection closes. The client's identity is taken from the From
// field of its FIRST message; every subsequent message refreshes LastSeen and
// is routed via RouteMessage. On disconnect the client registration is
// removed. A client that never sends a message is never registered.
func (s *SignalingServer) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("WebSocket upgrade failed: %v", err)
		return
	}
	defer conn.Close()

	var clientID string

	s.mu.Lock()
	s.stats.wsConnections++
	s.mu.Unlock()

	// Decrement the connection gauge no matter how the read loop exits.
	defer func() {
		s.mu.Lock()
		s.stats.wsConnections--
		s.mu.Unlock()
	}()

	log.Printf("[%s] New WebSocket connection", r.RemoteAddr)

	// Read and route messages from this WebSocket client
	for {
		var msg SignalingMessage
		if err := conn.ReadJSON(&msg); err != nil {
			// Normal close codes are expected; only log unexpected errors.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("[%s] WebSocket error: %v", r.RemoteAddr, err)
			}
			break
		}

		// Register client on first message
		// NOTE(review): if the first message has an empty From, the client is
		// registered under "" — confirm senders always populate From.
		if clientID == "" {
			clientID = msg.From
			s.mu.Lock()
			s.clients[clientID] = &Client{
				ID:       clientID,
				Conn:     conn,
				LastSeen: time.Now(),
			}
			s.mu.Unlock()
			log.Printf("[%s] WebSocket client registered as %s", r.RemoteAddr, clientID)
		}

		// Update last seen
		s.mu.Lock()
		if client, ok := s.clients[clientID]; ok {
			client.LastSeen = time.Now()
		}
		s.mu.Unlock()

		// Route the message to destination
		if err := s.RouteMessage(msg, r.RemoteAddr); err != nil {
			log.Printf("[%s] Failed to route message: %v", r.RemoteAddr, err)
		}
	}

	// Cleanup on disconnect
	if clientID != "" {
		s.mu.Lock()
		delete(s.clients, clientID)
		s.mu.Unlock()
		log.Printf("[%s] WebSocket client %s disconnected", r.RemoteAddr, clientID)
	}
}
||||
|
||||
// HTTP REST Handlers (for standard WebRTC clients)
|
||||
|
||||
func (s *SignalingServer) handlePostSignal(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var msg SignalingMessage
|
||||
if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if msg.From == "" || msg.To == "" || msg.Type == "" {
|
||||
http.Error(w, "Missing required fields: from, to, type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if msg.Type != "offer" && msg.Type != "answer" && msg.Type != "candidate" {
|
||||
http.Error(w, "Invalid type, must be 'offer', 'answer', or 'candidate'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Convert legacy SDP field to Offer/Answer format for compatibility
|
||||
if msg.SDP != "" {
|
||||
sdpJSON := json.RawMessage(fmt.Sprintf(`{"type":"%s","sdp":%q}`, msg.Type, msg.SDP))
|
||||
if msg.Type == "offer" {
|
||||
msg.Offer = sdpJSON
|
||||
} else if msg.Type == "answer" {
|
||||
msg.Answer = sdpJSON
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.RouteMessage(msg, r.RemoteAddr); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to route message: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "ok",
|
||||
"message": fmt.Sprintf("Message routed to %s", msg.To),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SignalingServer) handleGetSignal(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
to := r.URL.Query().Get("to")
|
||||
if to == "" {
|
||||
http.Error(w, "Missing required query parameter: to", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
messages := s.GetMessages(to, r.RemoteAddr)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if len(messages) == 0 {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "not_found",
|
||||
"message": fmt.Sprintf("No messages for %s", to),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
// Return first message (client should poll again for more)
|
||||
json.NewEncoder(w).Encode(messages[0])
|
||||
}
|
||||
|
||||
func (s *SignalingServer) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(s.GetStats())
|
||||
}
|
||||
|
||||
func (s *SignalingServer) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "healthy",
|
||||
"service": "webrtc-signaling-server",
|
||||
"version": "1.0.0",
|
||||
})
|
||||
}
|
||||
|
||||
// corsMiddleware wraps next with permissive CORS response headers and
// short-circuits OPTIONS preflight requests with 200 OK.
func corsMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		hdr.Set("Access-Control-Allow-Origin", "*")
		hdr.Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
		hdr.Set("Access-Control-Allow-Headers", "Content-Type")

		if r.Method != http.MethodOptions {
			next(w, r)
			return
		}
		w.WriteHeader(http.StatusOK)
	}
}
||||
|
||||
// loggingMiddleware wraps next and logs the remote address, method, path,
// and elapsed time once the wrapped handler returns.
func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		next(w, r)
		elapsed := time.Since(began)
		log.Printf("%s %s %s %s", r.RemoteAddr, r.Method, r.URL.Path, elapsed)
	}
}
||||
|
||||
// main parses flags, starts the periodic cleanup goroutine, registers the
// WebSocket (/ws), signaling (/signal), and monitoring (/stats, /health)
// handlers on the default mux, and serves HTTP — or HTTPS when both -cert
// and -key are given.
func main() {
	port := flag.Int("port", 8080, "Port to listen on")
	tlsCert := flag.String("cert", "", "TLS certificate file (optional, for HTTPS)")
	tlsKey := flag.String("key", "", "TLS key file (optional, for HTTPS)")
	cleanupInterval := flag.Duration("cleanup", 5*time.Minute, "Interval for cleaning up old messages")
	messageMaxAge := flag.Duration("max-age", 10*time.Minute, "Maximum age for messages before cleanup")
	flag.Parse()

	server := NewSignalingServer()

	// Start cleanup goroutine. It runs for the life of the process; there is
	// no shutdown signal, which is fine since main never returns normally.
	go func() {
		ticker := time.NewTicker(*cleanupInterval)
		defer ticker.Stop()

		for range ticker.C {
			server.CleanupOldMessages(*messageMaxAge)
		}
	}()

	// Register handlers
	// WebSocket endpoint (for Tailscale)
	http.HandleFunc("/ws", loggingMiddleware(server.handleWebSocket))

	// HTTP REST endpoints (for standard WebRTC clients)
	http.HandleFunc("/signal", corsMiddleware(loggingMiddleware(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodPost {
			server.handlePostSignal(w, r)
		} else if r.Method == http.MethodGet {
			server.handleGetSignal(w, r)
		} else {
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		}
	})))

	// Monitoring endpoints
	http.HandleFunc("/stats", corsMiddleware(loggingMiddleware(server.handleStats)))
	http.HandleFunc("/health", corsMiddleware(loggingMiddleware(server.handleHealth)))

	addr := fmt.Sprintf(":%d", *port)

	log.Printf("Server starting on %s", addr)
	log.Printf("WebSocket endpoint: ws://localhost%s/ws", addr)
	log.Printf("HTTP REST endpoint: http://localhost%s/signal", addr)
	log.Printf("Cleanup interval: %v, Max message age: %v", *cleanupInterval, *messageMaxAge)
	log.Println("────────────────────────────────────────────────────────────")

	// Serve until the listener fails; ListenAndServe(TLS) only returns on error.
	var err error
	if *tlsCert != "" && *tlsKey != "" {
		log.Printf("Starting HTTPS server with TLS...")
		err = http.ListenAndServeTLS(addr, *tlsCert, *tlsKey, nil)
	} else {
		log.Printf("Starting HTTP server (use -cert and -key for HTTPS)...")
		err = http.ListenAndServe(addr, nil)
	}

	if err != nil {
		log.Fatal(err)
	}
}
||||
21
go.mod
21
go.mod
@ -57,6 +57,7 @@ require (
|
||||
github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/goreleaser/nfpm/v2 v2.33.1
|
||||
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674
|
||||
github.com/hashicorp/go-hclog v1.6.2
|
||||
github.com/hashicorp/raft v1.7.2
|
||||
github.com/hashicorp/raft-boltdb/v2 v2.3.1
|
||||
@ -79,6 +80,7 @@ require (
|
||||
github.com/miekg/dns v1.1.58
|
||||
github.com/mitchellh/go-ps v1.0.0
|
||||
github.com/peterbourgon/ff/v3 v3.4.0
|
||||
github.com/pion/webrtc/v4 v4.2.11
|
||||
github.com/pires/go-proxyproto v0.8.1
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pkg/sftp v1.13.6
|
||||
@ -120,7 +122,7 @@ require (
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/sys v0.43.0
|
||||
golang.org/x/term v0.42.0
|
||||
golang.org/x/time v0.12.0
|
||||
golang.org/x/time v0.15.0
|
||||
golang.org/x/tools v0.44.0
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
@ -198,7 +200,6 @@ require (
|
||||
github.com/google/renameio/v2 v2.0.0 // indirect
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect
|
||||
github.com/gosuri/uitable v0.0.4 // indirect
|
||||
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
@ -232,6 +233,21 @@ require (
|
||||
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect
|
||||
github.com/pelletier/go-toml v1.9.5 // indirect
|
||||
github.com/peterbourgon/diskv v2.0.1+incompatible // indirect
|
||||
github.com/pion/datachannel v1.6.0 // indirect
|
||||
github.com/pion/dtls/v3 v3.1.2 // indirect
|
||||
github.com/pion/ice/v4 v4.2.5 // indirect
|
||||
github.com/pion/interceptor v0.1.44 // indirect
|
||||
github.com/pion/logging v0.2.4 // indirect
|
||||
github.com/pion/mdns/v2 v2.1.0 // indirect
|
||||
github.com/pion/randutil v0.1.0 // indirect
|
||||
github.com/pion/rtcp v1.2.16 // indirect
|
||||
github.com/pion/rtp v1.10.1 // indirect
|
||||
github.com/pion/sctp v1.9.4 // indirect
|
||||
github.com/pion/sdp/v3 v3.0.18 // indirect
|
||||
github.com/pion/srtp/v3 v3.0.10 // indirect
|
||||
github.com/pion/stun/v3 v3.1.2 // indirect
|
||||
github.com/pion/transport/v4 v4.0.1 // indirect
|
||||
github.com/pion/turn/v5 v5.0.3 // indirect
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
|
||||
github.com/puzpuzpuz/xsync v1.5.2 // indirect
|
||||
github.com/rtr7/dhcp4 v0.0.0-20220302171438-18c84d089b46 // indirect
|
||||
@ -241,6 +257,7 @@ require (
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 // indirect
|
||||
github.com/stacklok/frizbee v0.1.7 // indirect
|
||||
github.com/vishvananda/netlink v1.3.1-0.20240922070040-084abd93d350 // indirect
|
||||
github.com/wlynxg/anet v0.0.5 // indirect
|
||||
github.com/xen0n/gosmopolitan v1.2.2 // indirect
|
||||
github.com/xlab/treeprint v1.2.0 // indirect
|
||||
github.com/ykadowak/zerologlint v0.1.5 // indirect
|
||||
|
||||
42
go.sum
42
go.sum
@ -942,6 +942,42 @@ github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1H
|
||||
github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE=
|
||||
github.com/pierrec/lz4/v4 v4.1.25 h1:kocOqRffaIbU5djlIBr7Wh+cx82C0vtFb0fOurZHqD0=
|
||||
github.com/pierrec/lz4/v4 v4.1.25/go.mod h1:EoQMVJgeeEOMsCqCzqFm2O0cJvljX2nGZjcRIPL34O4=
|
||||
github.com/pion/datachannel v1.6.0 h1:XecBlj+cvsxhAMZWFfFcPyUaDZtd7IJvrXqlXD/53i0=
|
||||
github.com/pion/datachannel v1.6.0/go.mod h1:ur+wzYF8mWdC+Mkis5Thosk+u/VOL287apDNEbFpsIk=
|
||||
github.com/pion/dtls/v3 v3.1.2 h1:gqEdOUXLtCGW+afsBLO0LtDD8GnuBBjEy6HRtyofZTc=
|
||||
github.com/pion/dtls/v3 v3.1.2/go.mod h1:Hw/igcX4pdY69z1Hgv5x7wJFrUkdgHwAn/Q/uo7YHRo=
|
||||
github.com/pion/ice/v4 v4.2.5 h1:5umUQy4hX6HwMsCnJ0SX337YYCeTWDgC9JWyvUqHIHs=
|
||||
github.com/pion/ice/v4 v4.2.5/go.mod h1:aaABRaykEYnNjccjbiimuYxViaASeuv5mk9BpplUxK0=
|
||||
github.com/pion/interceptor v0.1.44 h1:sNlZwM8dWXU9JQAkJh8xrarC0Etn8Oolcniukmuy0/I=
|
||||
github.com/pion/interceptor v0.1.44/go.mod h1:4atVlBkcgXuUP+ykQF0qOCGU2j7pQzX2ofvPRFsY5RY=
|
||||
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
|
||||
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
|
||||
github.com/pion/mdns/v2 v2.1.0 h1:3IJ9+Xio6tWYjhN6WwuY142P/1jA0D5ERaIqawg/fOY=
|
||||
github.com/pion/mdns/v2 v2.1.0/go.mod h1:pcez23GdynwcfRU1977qKU0mDxSeucttSHbCSfFOd9A=
|
||||
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
|
||||
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
|
||||
github.com/pion/rtcp v1.2.16 h1:fk1B1dNW4hsI78XUCljZJlC4kZOPk67mNRuQ0fcEkSo=
|
||||
github.com/pion/rtcp v1.2.16/go.mod h1:/as7VKfYbs5NIb4h6muQ35kQF/J0ZVNz2Z3xKoCBYOo=
|
||||
github.com/pion/rtp v1.10.1 h1:xP1prZcCTUuhO2c83XtxyOHJteISg6o8iPsE2acaMtA=
|
||||
github.com/pion/rtp v1.10.1/go.mod h1:rF5nS1GqbR7H/TCpKwylzeq6yDM+MM6k+On5EgeThEM=
|
||||
github.com/pion/sctp v1.9.4 h1:cMxEu0F5tbP4qH07bKf1Zjf4rUih9LIo0qQt424e258=
|
||||
github.com/pion/sctp v1.9.4/go.mod h1:N20Dq6LY+JvJDAh9VVh1JELngb2rQ8dPgds5yBWiPgw=
|
||||
github.com/pion/sdp/v3 v3.0.18 h1:l0bAXazKHpepazVdp+tPYnrsy9dfh7ZbT8DxesH5ZnI=
|
||||
github.com/pion/sdp/v3 v3.0.18/go.mod h1:ZREGo6A9ZygQ9XkqAj5xYCQtQpif0i6Pa81HOiAdqQ8=
|
||||
github.com/pion/srtp/v3 v3.0.10 h1:tFirkpBb3XccP5VEXLi50GqXhv5SKPxqrdlhDCJlZrQ=
|
||||
github.com/pion/srtp/v3 v3.0.10/go.mod h1:3mOTIB0cq9qlbn59V4ozvv9ClW/BSEbRp4cY0VtaR7M=
|
||||
github.com/pion/stun/v3 v3.1.2 h1:86IhD8wFn6IDW4b1/0QzoQS+f5PeA8OHHRn8UZW5ErY=
|
||||
github.com/pion/stun/v3 v3.1.2/go.mod h1:H7gDic7nNwlUL05pbs6T1dtaBehh/KjupxfWw3ZI7cA=
|
||||
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
|
||||
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
|
||||
github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o=
|
||||
github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM=
|
||||
github.com/pion/turn/v4 v4.1.4 h1:EU11yMXKIsK43FhcUnjLlrhE4nboHZq+TXBIi3QpcxQ=
|
||||
github.com/pion/turn/v4 v4.1.4/go.mod h1:ES1DXVFKnOhuDkqn9hn5VJlSWmZPaRJLyBXoOeO/BmQ=
|
||||
github.com/pion/turn/v5 v5.0.3 h1:I+Nw0fQgdPWF1SXDj0egWDhCkcff7gWiigdQpOK52Ak=
|
||||
github.com/pion/turn/v5 v5.0.3/go.mod h1:fs4SogUh/aRGQzonc4Lx3Jp4EU3j3t0PfNDEd9KcD/w=
|
||||
github.com/pion/webrtc/v4 v4.2.11 h1:QUX1QZKlNIn4O7U5JxLPGP0sV5RTncZkzu9SPR3jVNU=
|
||||
github.com/pion/webrtc/v4 v4.2.11/go.mod h1:s/rAiyy77GyRFrZMx+Ls6aua26dIBPudH8/ZHYbIRWY=
|
||||
github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0=
|
||||
github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
|
||||
github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4=
|
||||
@ -1204,6 +1240,8 @@ github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
|
||||
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
|
||||
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
|
||||
@ -1525,8 +1563,8 @@ golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||
|
||||
@ -651,7 +651,18 @@ table tbody tr:nth-child(even) td { background-color: #f5f5f5; }
|
||||
if ps.Relay != "" && ps.CurAddr == "" {
|
||||
f("relay <b>%s</b>", html.EscapeString(ps.Relay))
|
||||
} else if ps.CurAddr != "" {
|
||||
f("direct <b>%s</b>", html.EscapeString(ps.CurAddr))
|
||||
// Check if this is a WebRTC connection (magic IP 127.3.3.41)
|
||||
if strings.HasPrefix(ps.CurAddr, "127.3.3.41:") {
|
||||
// Extract the actual remote address if present
|
||||
if idx := strings.Index(ps.CurAddr, " ("); idx > 0 {
|
||||
remoteAddr := ps.CurAddr[idx+2 : len(ps.CurAddr)-1] // Extract address from " (addr)"
|
||||
f("webrtc <b>%s</b>", html.EscapeString(remoteAddr))
|
||||
} else {
|
||||
f("webrtc")
|
||||
}
|
||||
} else {
|
||||
f("direct <b>%s</b>", html.EscapeString(ps.CurAddr))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
53
rtclib/signaling.go
Normal file
53
rtclib/signaling.go
Normal file
@ -0,0 +1,53 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rtclib
|
||||
|
||||
import "github.com/pion/webrtc/v4"
|
||||
|
||||
// Signaling message types carried in SignalingMessage.Type.
const (
	MessageTypeOffer     = "offer"     // SDP offer (SignalingMessage.Offer is set)
	MessageTypeAnswer    = "answer"    // SDP answer (SignalingMessage.Answer is set)
	MessageTypeCandidate = "candidate" // ICE candidate (SignalingMessage.Candidate is set)
)
||||
|
||||
// SignalingMessage represents a message exchanged over the signaling channel.
// Exactly one of Offer, Answer, or Candidate is expected to be non-nil,
// matching the Type field.
type SignalingMessage struct {
	Type string `json:"type"` // "offer", "answer", "candidate"
	From string `json:"from"` // sender's disco public key (hex)
	To   string `json:"to"`   // recipient's disco public key (hex)

	Offer     *webrtc.SessionDescription `json:"offer,omitempty"`     // set when Type == MessageTypeOffer
	Answer    *webrtc.SessionDescription `json:"answer,omitempty"`    // set when Type == MessageTypeAnswer
	Candidate *webrtc.ICECandidateInit   `json:"candidate,omitempty"` // set when Type == MessageTypeCandidate
}
||||
|
||||
// SignalHandler defines callbacks for handling incoming signaling messages.
|
||||
type SignalHandler interface {
|
||||
// HandleOffer is called when an offer is received from a peer.
|
||||
HandleOffer(from, to string, offer *webrtc.SessionDescription)
|
||||
|
||||
// HandleAnswer is called when an answer is received from a peer.
|
||||
HandleAnswer(from, to string, answer *webrtc.SessionDescription)
|
||||
|
||||
// HandleCandidate is called when an ICE candidate is received from a peer.
|
||||
HandleCandidate(from, to string, candidate *webrtc.ICECandidateInit)
|
||||
}
|
||||
|
||||
// Signaller defines the interface for WebRTC signaling implementations.
|
||||
type Signaller interface {
|
||||
// Start begins the signaling connection with the provided handler.
|
||||
Start(handler SignalHandler) error
|
||||
|
||||
// Offer sends an SDP offer to a peer.
|
||||
Offer(from, to string, offer *webrtc.SessionDescription) error
|
||||
|
||||
// Answer sends an SDP answer to a peer.
|
||||
Answer(from, to string, answer *webrtc.SessionDescription) error
|
||||
|
||||
// Candidate sends an ICE candidate to a peer.
|
||||
Candidate(from, to string, candidate *webrtc.ICECandidateInit) error
|
||||
|
||||
// Close shuts down the signaling connection.
|
||||
Close() error
|
||||
}
|
||||
@ -3301,6 +3301,14 @@ const DerpMagicIP = "127.3.3.40"
|
||||
|
||||
var DerpMagicIPAddr = netip.MustParseAddr(DerpMagicIP)
|
||||
|
||||
// WebRTCMagicIP is a fake WireGuard endpoint IP address that means
|
||||
// to use WebRTC data channel for packet transmission.
|
||||
//
|
||||
// Mnemonic: 127.3.3.41 is one above DerpMagicIP for WebRTC.
|
||||
const WebRTCMagicIP = "127.3.3.41"
|
||||
|
||||
var WebRTCMagicIPAddr = netip.MustParseAddr(WebRTCMagicIP)
|
||||
|
||||
// EarlyNoise is the early payload that's sent over Noise but before the HTTP/2
|
||||
// handshake when connecting to the coordination server.
|
||||
//
|
||||
|
||||
@ -66,6 +66,9 @@ var (
|
||||
// suppressing/dropping inbound/outbound [disco.Ping] messages, forcing
|
||||
// all peer communication over DERP or peer relay.
|
||||
debugNeverDirectUDP = envknob.RegisterBool("TS_DEBUG_NEVER_DIRECT_UDP")
|
||||
// debugAlwaysWebRTC forces all peer communication over WebRTC by
|
||||
// suppressing disco pings to direct UDP and DERP addresses.
|
||||
debugAlwaysWebRTC = envknob.RegisterBool("TS_DEBUG_ALWAYS_USE_WEBRTC")
|
||||
// Hey you! Adding a new debugknob? Make sure to stub it out in the
|
||||
// debugknobs_stubs.go file too.
|
||||
)
|
||||
|
||||
@ -32,3 +32,4 @@ func inTest() bool { return false }
|
||||
func debugPeerMap() bool { return false }
|
||||
func pretendpoints() []netip.AddrPort { return []netip.AddrPort{} }
|
||||
func debugNeverDirectUDP() bool { return false }
|
||||
func debugAlwaysWebRTC() bool { return false }
|
||||
|
||||
@ -1076,7 +1076,13 @@ func (de *endpoint) send(buffs [][]byte, offset int) error {
|
||||
}
|
||||
}
|
||||
var err error
|
||||
if udpAddr.ap.IsValid() {
|
||||
// Check if this is a WebRTC address and route accordingly
|
||||
if udpAddr.ap.IsValid() && udpAddr.ap.Addr() == tailcfg.WebRTCMagicIPAddr {
|
||||
// Pack all buffs into one SCTP message. See sendWebRTCBatch for why.
|
||||
if err = de.c.sendWebRTCBatch(de.publicKey, buffs, offset); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if udpAddr.ap.IsValid() {
|
||||
_, err = de.c.sendUDPBatch(udpAddr, buffs, offset)
|
||||
|
||||
// If the error is known to indicate that the endpoint is no longer
|
||||
@ -1116,6 +1122,10 @@ func (de *endpoint) send(buffs [][]byte, offset int) error {
|
||||
}
|
||||
}
|
||||
if derpAddr.IsValid() {
|
||||
// Traffic is flowing via DERP; opportunistically upgrade to WebRTC.
|
||||
if mgr := de.c.webrtcMgr; mgr != nil {
|
||||
mgr.ensureConnecting(de)
|
||||
}
|
||||
allOk := true
|
||||
var txBytes int
|
||||
for _, buff := range buffs {
|
||||
@ -1294,6 +1304,9 @@ func (de *endpoint) startDiscoPingLocked(ep epAddr, now mono.Time, purpose disco
|
||||
if debugNeverDirectUDP() && !ep.vni.IsSet() && ep.ap.Addr() != tailcfg.DerpMagicIPAddr {
|
||||
return
|
||||
}
|
||||
if debugAlwaysWebRTC() && !ep.vni.IsSet() && ep.ap.Addr() != tailcfg.WebRTCMagicIPAddr {
|
||||
return
|
||||
}
|
||||
epDisco := de.disco.Load()
|
||||
if epDisco == nil {
|
||||
return
|
||||
@ -1810,10 +1823,13 @@ type epAddr struct {
|
||||
vni packet.VirtualNetworkID // vni.IsSet() indicates if this [epAddr] involves a Geneve header
|
||||
}
|
||||
|
||||
// isDirect returns true if e.ap is valid and not tailcfg.DerpMagicIPAddr,
|
||||
// isDirect returns true if e.ap is valid and not tailcfg.DerpMagicIPAddr or WebRTCMagicIPAddr,
|
||||
// and a VNI is not set.
|
||||
func (e epAddr) isDirect() bool {
|
||||
return e.ap.IsValid() && e.ap.Addr() != tailcfg.DerpMagicIPAddr && !e.vni.IsSet()
|
||||
return e.ap.IsValid() &&
|
||||
e.ap.Addr() != tailcfg.DerpMagicIPAddr &&
|
||||
e.ap.Addr() != tailcfg.WebRTCMagicIPAddr &&
|
||||
!e.vni.IsSet()
|
||||
}
|
||||
|
||||
func (e epAddr) String() string {
|
||||
@ -1866,6 +1882,28 @@ func betterAddr(a, b addrQuality) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// WebRTC path priority: Direct UDP > WebRTC > Peer Relay/DERP
|
||||
aIsWebRTC := a.ap.Addr() == tailcfg.WebRTCMagicIPAddr
|
||||
bIsWebRTC := b.ap.Addr() == tailcfg.WebRTCMagicIPAddr
|
||||
aIsDERP := a.ap.Addr() == tailcfg.DerpMagicIPAddr
|
||||
bIsDERP := b.ap.Addr() == tailcfg.DerpMagicIPAddr
|
||||
|
||||
// Direct paths beat WebRTC
|
||||
if a.isDirect() && bIsWebRTC {
|
||||
return true
|
||||
}
|
||||
if b.isDirect() && aIsWebRTC {
|
||||
return false
|
||||
}
|
||||
|
||||
// WebRTC beats DERP and relay (VNI)
|
||||
if aIsWebRTC && (bIsDERP || b.vni.IsSet()) {
|
||||
return true
|
||||
}
|
||||
if bIsWebRTC && (aIsDERP || a.vni.IsSet()) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Each address starts with a set of points (from 0 to 100) that
|
||||
// represents how much faster they are than the highest-latency
|
||||
// endpoint. For example, if a has latency 200ms and b has latency
|
||||
@ -1886,19 +1924,26 @@ func betterAddr(a, b addrQuality) bool {
|
||||
// addresses, and prefer link-local unicast addresses over other types
|
||||
// of private IP addresses since it's definitionally more likely that
|
||||
// they'll be on the same network segment than a general private IP.
|
||||
if a.ap.Addr().IsLoopback() {
|
||||
aPoints += 50
|
||||
} else if a.ap.Addr().IsLinkLocalUnicast() {
|
||||
aPoints += 30
|
||||
} else if a.ap.Addr().IsPrivate() {
|
||||
aPoints += 20
|
||||
//
|
||||
// Exclude magic IPs (DERP, WebRTC) from these bonuses as they're not
|
||||
// real network paths.
|
||||
if !aIsDERP && !aIsWebRTC {
|
||||
if a.ap.Addr().IsLoopback() {
|
||||
aPoints += 50
|
||||
} else if a.ap.Addr().IsLinkLocalUnicast() {
|
||||
aPoints += 30
|
||||
} else if a.ap.Addr().IsPrivate() {
|
||||
aPoints += 20
|
||||
}
|
||||
}
|
||||
if b.ap.Addr().IsLoopback() {
|
||||
bPoints += 50
|
||||
} else if b.ap.Addr().IsLinkLocalUnicast() {
|
||||
bPoints += 30
|
||||
} else if b.ap.Addr().IsPrivate() {
|
||||
bPoints += 20
|
||||
if !bIsDERP && !bIsWebRTC {
|
||||
if b.ap.Addr().IsLoopback() {
|
||||
bPoints += 50
|
||||
} else if b.ap.Addr().IsLinkLocalUnicast() {
|
||||
bPoints += 30
|
||||
} else if b.ap.Addr().IsPrivate() {
|
||||
bPoints += 20
|
||||
}
|
||||
}
|
||||
|
||||
// Prefer IPv6 for being a bit more robust, as long as
|
||||
@ -2015,9 +2060,8 @@ func (de *endpoint) populatePeerStatus(ps *ipnstate.PeerStatus) {
|
||||
de.mu.Lock()
|
||||
defer de.mu.Unlock()
|
||||
|
||||
ps.Relay = de.c.derpRegionCodeOfIDLocked(int(de.derpAddr.Port()))
|
||||
|
||||
if de.lastSendExt.IsZero() {
|
||||
ps.Relay = de.c.derpRegionCodeOfIDLocked(int(de.derpAddr.Port()))
|
||||
return
|
||||
}
|
||||
|
||||
@ -2030,7 +2074,18 @@ func (de *endpoint) populatePeerStatus(ps *ipnstate.PeerStatus) {
|
||||
ps.PeerRelay = udpAddr.String()
|
||||
} else {
|
||||
ps.CurAddr = udpAddr.String()
|
||||
// If this is a WebRTC connection, append the actual remote address
|
||||
if udpAddr.ap.Addr() == tailcfg.WebRTCMagicIPAddr && de.c.webrtcMgr != nil {
|
||||
if disco := de.disco.Load(); disco != nil {
|
||||
if remoteAddr := de.c.webrtcMgr.getRemoteAddr(disco.key); remoteAddr.IsValid() {
|
||||
ps.CurAddr = fmt.Sprintf("%s (%s)", ps.CurAddr, remoteAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Not on a direct or WebRTC path; show the DERP relay being used.
|
||||
ps.Relay = de.c.derpRegionCodeOfIDLocked(int(de.derpAddr.Port()))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -25,6 +26,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pion/webrtc/v4"
|
||||
"github.com/tailscale/wireguard-go/conn"
|
||||
"github.com/tailscale/wireguard-go/device"
|
||||
"go4.org/mem"
|
||||
@ -94,6 +96,7 @@ type Path string
|
||||
const (
|
||||
PathDirectIPv4 Path = "direct_ipv4"
|
||||
PathDirectIPv6 Path = "direct_ipv6"
|
||||
PathWebRTC Path = "webrtc"
|
||||
PathDERP Path = "derp"
|
||||
PathPeerRelayIPv4 Path = "peer_relay_ipv4"
|
||||
PathPeerRelayIPv6 Path = "peer_relay_ipv6"
|
||||
@ -109,6 +112,14 @@ type pathLabel struct {
|
||||
Path Path
|
||||
}
|
||||
|
||||
// webrtcReadResult is the result of reading a packet from a WebRTC data channel.
|
||||
// It is similar to derpReadResult but for WebRTC connections.
|
||||
type webrtcReadResult struct {
|
||||
n int // length of data in buf
|
||||
src key.NodePublic // sender's node public key
|
||||
buf []byte // packet data; nil signals the receiver to ignore this message
|
||||
}
|
||||
|
||||
// metrics in wgengine contains the usermetrics counters for magicsock, it
|
||||
// is however a bit special. All them metrics are labeled, but looking up
|
||||
// the metric everytime we need to record it has an overhead, and includes
|
||||
@ -119,6 +130,7 @@ type metrics struct {
|
||||
// labeled by the path the packet took.
|
||||
inboundPacketsIPv4Total expvar.Int
|
||||
inboundPacketsIPv6Total expvar.Int
|
||||
inboundPacketsWebRTCTotal expvar.Int
|
||||
inboundPacketsDERPTotal expvar.Int
|
||||
inboundPacketsPeerRelayIPv4Total expvar.Int
|
||||
inboundPacketsPeerRelayIPv6Total expvar.Int
|
||||
@ -127,6 +139,7 @@ type metrics struct {
|
||||
// labeled by the path the packet took.
|
||||
inboundBytesIPv4Total expvar.Int
|
||||
inboundBytesIPv6Total expvar.Int
|
||||
inboundBytesWebRTCTotal expvar.Int
|
||||
inboundBytesDERPTotal expvar.Int
|
||||
inboundBytesPeerRelayIPv4Total expvar.Int
|
||||
inboundBytesPeerRelayIPv6Total expvar.Int
|
||||
@ -135,6 +148,7 @@ type metrics struct {
|
||||
// labeled by the path the packet took.
|
||||
outboundPacketsIPv4Total expvar.Int
|
||||
outboundPacketsIPv6Total expvar.Int
|
||||
outboundPacketsWebRTCTotal expvar.Int
|
||||
outboundPacketsDERPTotal expvar.Int
|
||||
outboundPacketsPeerRelayIPv4Total expvar.Int
|
||||
outboundPacketsPeerRelayIPv6Total expvar.Int
|
||||
@ -143,6 +157,7 @@ type metrics struct {
|
||||
// labeled by the path the packet took.
|
||||
outboundBytesIPv4Total expvar.Int
|
||||
outboundBytesIPv6Total expvar.Int
|
||||
outboundBytesWebRTCTotal expvar.Int
|
||||
outboundBytesDERPTotal expvar.Int
|
||||
outboundBytesPeerRelayIPv4Total expvar.Int
|
||||
outboundBytesPeerRelayIPv6Total expvar.Int
|
||||
@ -211,6 +226,10 @@ type Conn struct {
|
||||
// It must have buffer size > 0; see issue 3736.
|
||||
derpRecvCh chan derpReadResult
|
||||
|
||||
// webrtcRecvCh is used by receiveWebRTC to read WebRTC messages.
|
||||
// It must have buffer size > 0, similar to derpRecvCh.
|
||||
webrtcRecvCh chan webrtcReadResult
|
||||
|
||||
// bind is the wireguard-go conn.Bind for Conn.
|
||||
bind *connBind
|
||||
|
||||
@ -337,6 +356,10 @@ type Conn struct {
|
||||
// [tailscale.com/net/udprelay.Server] endpoints.
|
||||
relayManager relayManager
|
||||
|
||||
// webrtcMgr manages WebRTC connections for peers.
|
||||
// May be nil if WebRTC failed to initialize.
|
||||
webrtcMgr *webrtcManager
|
||||
|
||||
// discoInfo is the state for an active peer DiscoKey.
|
||||
discoInfo map[key.DiscoPublic]*discoInfo
|
||||
|
||||
@ -569,7 +592,8 @@ func newConn(logf logger.Logf) *Conn {
|
||||
discoPrivate := key.NewDisco()
|
||||
c := &Conn{
|
||||
logf: logf,
|
||||
derpRecvCh: make(chan derpReadResult, 1), // must be buffered, see issue 3736
|
||||
derpRecvCh: make(chan derpReadResult, 1), // must be buffered, see issue 3736
|
||||
webrtcRecvCh: make(chan webrtcReadResult, 64), // must be buffered, similar to derpRecvCh
|
||||
derpStarted: make(chan struct{}),
|
||||
peerLastDerp: make(map[key.NodePublic]int),
|
||||
peerMap: newPeerMap(),
|
||||
@ -721,6 +745,14 @@ func NewConn(opts Options) (*Conn, error) {
|
||||
}
|
||||
|
||||
c.logf("magicsock: disco key = %v", c.discoAtomic.Short())
|
||||
|
||||
// Initialize WebRTC manager with disco-based signaling.
|
||||
c.logf("magicsock: initializing WebRTC with disco signaling")
|
||||
c.webrtcMgr = newWebRTCManager(c)
|
||||
if c.webrtcMgr == nil {
|
||||
c.logf("magicsock: failed to initialize WebRTC manager")
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
@ -730,6 +762,7 @@ func NewConn(opts Options) (*Conn, error) {
|
||||
func registerMetrics(reg *usermetric.Registry) *metrics {
|
||||
pathDirectV4 := pathLabel{Path: PathDirectIPv4}
|
||||
pathDirectV6 := pathLabel{Path: PathDirectIPv6}
|
||||
pathWebRTC := pathLabel{Path: PathWebRTC}
|
||||
pathDERP := pathLabel{Path: PathDERP}
|
||||
pathPeerRelayV4 := pathLabel{Path: PathPeerRelayIPv4}
|
||||
pathPeerRelayV6 := pathLabel{Path: PathPeerRelayIPv6}
|
||||
@ -764,21 +797,25 @@ func registerMetrics(reg *usermetric.Registry) *metrics {
|
||||
// Map clientmetrics to the usermetric counters.
|
||||
metricRecvDataPacketsIPv4.Register(&m.inboundPacketsIPv4Total)
|
||||
metricRecvDataPacketsIPv6.Register(&m.inboundPacketsIPv6Total)
|
||||
metricRecvDataPacketsWebRTC.Register(&m.inboundPacketsWebRTCTotal)
|
||||
metricRecvDataPacketsDERP.Register(&m.inboundPacketsDERPTotal)
|
||||
metricRecvDataPacketsPeerRelayIPv4.Register(&m.inboundPacketsPeerRelayIPv4Total)
|
||||
metricRecvDataPacketsPeerRelayIPv6.Register(&m.inboundPacketsPeerRelayIPv6Total)
|
||||
metricRecvDataBytesIPv4.Register(&m.inboundBytesIPv4Total)
|
||||
metricRecvDataBytesIPv6.Register(&m.inboundBytesIPv6Total)
|
||||
metricRecvDataBytesWebRTC.Register(&m.inboundBytesWebRTCTotal)
|
||||
metricRecvDataBytesDERP.Register(&m.inboundBytesDERPTotal)
|
||||
metricRecvDataBytesPeerRelayIPv4.Register(&m.inboundBytesPeerRelayIPv4Total)
|
||||
metricRecvDataBytesPeerRelayIPv6.Register(&m.inboundBytesPeerRelayIPv6Total)
|
||||
metricSendDataPacketsIPv4.Register(&m.outboundPacketsIPv4Total)
|
||||
metricSendDataPacketsIPv6.Register(&m.outboundPacketsIPv6Total)
|
||||
metricSendDataPacketsWebRTC.Register(&m.outboundPacketsWebRTCTotal)
|
||||
metricSendDataPacketsDERP.Register(&m.outboundPacketsDERPTotal)
|
||||
metricSendDataPacketsPeerRelayIPv4.Register(&m.outboundPacketsPeerRelayIPv4Total)
|
||||
metricSendDataPacketsPeerRelayIPv6.Register(&m.outboundPacketsPeerRelayIPv6Total)
|
||||
metricSendDataBytesIPv4.Register(&m.outboundBytesIPv4Total)
|
||||
metricSendDataBytesIPv6.Register(&m.outboundBytesIPv6Total)
|
||||
metricSendDataBytesWebRTC.Register(&m.outboundBytesWebRTCTotal)
|
||||
metricSendDataBytesDERP.Register(&m.outboundBytesDERPTotal)
|
||||
metricSendDataBytesPeerRelayIPv4.Register(&m.outboundBytesPeerRelayIPv4Total)
|
||||
metricSendDataBytesPeerRelayIPv6.Register(&m.outboundBytesPeerRelayIPv6Total)
|
||||
@ -790,24 +827,28 @@ func registerMetrics(reg *usermetric.Registry) *metrics {
|
||||
|
||||
inboundPacketsTotal.Set(pathDirectV4, &m.inboundPacketsIPv4Total)
|
||||
inboundPacketsTotal.Set(pathDirectV6, &m.inboundPacketsIPv6Total)
|
||||
inboundPacketsTotal.Set(pathWebRTC, &m.inboundPacketsWebRTCTotal)
|
||||
inboundPacketsTotal.Set(pathDERP, &m.inboundPacketsDERPTotal)
|
||||
inboundPacketsTotal.Set(pathPeerRelayV4, &m.inboundPacketsPeerRelayIPv4Total)
|
||||
inboundPacketsTotal.Set(pathPeerRelayV6, &m.inboundPacketsPeerRelayIPv6Total)
|
||||
|
||||
inboundBytesTotal.Set(pathDirectV4, &m.inboundBytesIPv4Total)
|
||||
inboundBytesTotal.Set(pathDirectV6, &m.inboundBytesIPv6Total)
|
||||
inboundBytesTotal.Set(pathWebRTC, &m.inboundBytesWebRTCTotal)
|
||||
inboundBytesTotal.Set(pathDERP, &m.inboundBytesDERPTotal)
|
||||
inboundBytesTotal.Set(pathPeerRelayV4, &m.inboundBytesPeerRelayIPv4Total)
|
||||
inboundBytesTotal.Set(pathPeerRelayV6, &m.inboundBytesPeerRelayIPv6Total)
|
||||
|
||||
outboundPacketsTotal.Set(pathDirectV4, &m.outboundPacketsIPv4Total)
|
||||
outboundPacketsTotal.Set(pathDirectV6, &m.outboundPacketsIPv6Total)
|
||||
outboundPacketsTotal.Set(pathWebRTC, &m.outboundPacketsWebRTCTotal)
|
||||
outboundPacketsTotal.Set(pathDERP, &m.outboundPacketsDERPTotal)
|
||||
outboundPacketsTotal.Set(pathPeerRelayV4, &m.outboundPacketsPeerRelayIPv4Total)
|
||||
outboundPacketsTotal.Set(pathPeerRelayV6, &m.outboundPacketsPeerRelayIPv6Total)
|
||||
|
||||
outboundBytesTotal.Set(pathDirectV4, &m.outboundBytesIPv4Total)
|
||||
outboundBytesTotal.Set(pathDirectV6, &m.outboundBytesIPv6Total)
|
||||
outboundBytesTotal.Set(pathWebRTC, &m.outboundBytesWebRTCTotal)
|
||||
outboundBytesTotal.Set(pathDERP, &m.outboundBytesDERPTotal)
|
||||
outboundBytesTotal.Set(pathPeerRelayV4, &m.outboundBytesPeerRelayIPv4Total)
|
||||
outboundBytesTotal.Set(pathPeerRelayV6, &m.outboundBytesPeerRelayIPv6Total)
|
||||
@ -822,21 +863,25 @@ func registerMetrics(reg *usermetric.Registry) *metrics {
|
||||
func deregisterMetrics() {
|
||||
metricRecvDataPacketsIPv4.UnregisterAll()
|
||||
metricRecvDataPacketsIPv6.UnregisterAll()
|
||||
metricRecvDataPacketsWebRTC.UnregisterAll()
|
||||
metricRecvDataPacketsDERP.UnregisterAll()
|
||||
metricRecvDataPacketsPeerRelayIPv4.UnregisterAll()
|
||||
metricRecvDataPacketsPeerRelayIPv6.UnregisterAll()
|
||||
metricRecvDataBytesIPv4.UnregisterAll()
|
||||
metricRecvDataBytesIPv6.UnregisterAll()
|
||||
metricRecvDataBytesWebRTC.UnregisterAll()
|
||||
metricRecvDataBytesDERP.UnregisterAll()
|
||||
metricRecvDataBytesPeerRelayIPv4.UnregisterAll()
|
||||
metricRecvDataBytesPeerRelayIPv6.UnregisterAll()
|
||||
metricSendDataPacketsIPv4.UnregisterAll()
|
||||
metricSendDataPacketsIPv6.UnregisterAll()
|
||||
metricSendDataPacketsWebRTC.UnregisterAll()
|
||||
metricSendDataPacketsDERP.UnregisterAll()
|
||||
metricSendDataPacketsPeerRelayIPv4.UnregisterAll()
|
||||
metricSendDataPacketsPeerRelayIPv6.UnregisterAll()
|
||||
metricSendDataBytesIPv4.UnregisterAll()
|
||||
metricSendDataBytesIPv6.UnregisterAll()
|
||||
metricSendDataBytesWebRTC.UnregisterAll()
|
||||
metricSendDataBytesDERP.UnregisterAll()
|
||||
metricSendDataBytesPeerRelayIPv4.UnregisterAll()
|
||||
metricSendDataBytesPeerRelayIPv6.UnregisterAll()
|
||||
@ -1624,6 +1669,9 @@ func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error)
|
||||
// IPv6 address when the local machine doesn't have IPv6 support
|
||||
// returns (false, nil); it's not an error, but nothing was sent.
|
||||
func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte, isDisco bool, isGeneveEncap bool) (sent bool, err error) {
|
||||
if addr.Addr() == tailcfg.WebRTCMagicIPAddr {
|
||||
return c.sendWebRTC(addr, pubKey, b)
|
||||
}
|
||||
if addr.Addr() != tailcfg.DerpMagicIPAddr {
|
||||
return c.sendUDP(addr, b, isDisco, isGeneveEncap)
|
||||
}
|
||||
@ -1665,6 +1713,170 @@ func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte, is
|
||||
return false, errDropDerpPacket
|
||||
}
|
||||
|
||||
// webrtcBatchMagic is the first byte of a batched WebRTC SCTP message.
|
||||
// WireGuard packets start with 0x01–0x04, disco packets start with 0x54 ('T'),
|
||||
// so 0xBA is unambiguous.
|
||||
const webrtcBatchMagic = byte(0xBA)
|
||||
|
||||
// sendWebRTCBatch packs all buffs into a single SCTP message and sends it.
|
||||
// Batching is the critical throughput optimization: the per-packet WebRTC
|
||||
// path calls rwc.Write once per WireGuard packet, producing one SCTP message,
|
||||
// one DTLS record, and one UDP send per packet. Packing N packets into one
|
||||
// SCTP message reduces that to a single write — the same advantage
|
||||
// sendUDPBatch (sendmmsg) gives the regular UDP path.
|
||||
//
|
||||
// Wire format for N>1: [0xBA magic][2-byte BE len][packet]...[2-byte BE len][packet]
|
||||
// Single packet: sent as-is with no framing overhead.
|
||||
func (c *Conn) sendWebRTCBatch(pubKey key.NodePublic, buffs [][]byte, offset int) error {
|
||||
if c.webrtcMgr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve endpoint and disco key once for the whole batch.
|
||||
c.mu.Lock()
|
||||
ep, ok := c.peerMap.endpointForNodeKey(pubKey)
|
||||
c.mu.Unlock()
|
||||
if !ok || ep == nil {
|
||||
return nil
|
||||
}
|
||||
disco := ep.disco.Load()
|
||||
if disco == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(buffs) == 1 {
|
||||
// Fast path: single packet, no framing overhead.
|
||||
b := buffs[0][offset:]
|
||||
if err := c.webrtcMgr.sendPacket(disco.key, b); err != nil {
|
||||
return err
|
||||
}
|
||||
c.metrics.outboundPacketsWebRTCTotal.Add(1)
|
||||
c.metrics.outboundBytesWebRTCTotal.Add(int64(len(b)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Multi-packet batch path.
|
||||
size := 1 // magic byte
|
||||
for _, b := range buffs {
|
||||
size += 2 + len(b[offset:])
|
||||
}
|
||||
batch := make([]byte, size)
|
||||
batch[0] = webrtcBatchMagic
|
||||
pos := 1
|
||||
var totalBytes int64
|
||||
for _, b := range buffs {
|
||||
pkt := b[offset:]
|
||||
binary.BigEndian.PutUint16(batch[pos:], uint16(len(pkt)))
|
||||
pos += 2
|
||||
copy(batch[pos:], pkt)
|
||||
pos += len(pkt)
|
||||
totalBytes += int64(len(pkt))
|
||||
}
|
||||
if err := c.webrtcMgr.sendPacket(disco.key, batch); err != nil {
|
||||
return err
|
||||
}
|
||||
c.metrics.outboundPacketsWebRTCTotal.Add(int64(len(buffs)))
|
||||
c.metrics.outboundBytesWebRTCTotal.Add(totalBytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendWebRTC sends a packet over WebRTC data channel.
|
||||
func (c *Conn) sendWebRTC(addr netip.AddrPort, pubKey key.NodePublic, b []byte) (sent bool, err error) {
|
||||
if c.webrtcMgr == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Find the endpoint by public key
|
||||
c.mu.Lock()
|
||||
ep, ok := c.peerMap.endpointForNodeKey(pubKey)
|
||||
c.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Get the disco key for this endpoint
|
||||
disco := ep.disco.Load()
|
||||
if disco == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Send via WebRTC manager
|
||||
if err := c.webrtcMgr.sendPacket(disco.key, b); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Update metrics
|
||||
c.metrics.outboundPacketsWebRTCTotal.Add(1)
|
||||
c.metrics.outboundBytesWebRTCTotal.Add(int64(len(b)))
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// receiveWebRTC handles packets received from WebRTC data channels.
|
||||
// This is called by webrtcManager when data arrives on a data channel.
|
||||
// receiveWebRTC is called by webrtcManager when data arrives on a WebRTC data channel.
|
||||
// It queues the packet for processing by wireguard-go through the webrtcRecvCh channel.
|
||||
func (c *Conn) receiveWebRTC(b []byte, srcNodeKey key.NodePublic) {
|
||||
// Copy into a fresh slice: b belongs to the reader goroutine's reusable
|
||||
// buffer which will be overwritten on the next Read call.
|
||||
pkt := make([]byte, len(b))
|
||||
copy(pkt, b)
|
||||
|
||||
select {
|
||||
case c.webrtcRecvCh <- webrtcReadResult{n: len(pkt), src: srcNodeKey, buf: pkt}:
|
||||
case <-c.connCtx.Done():
|
||||
default:
|
||||
c.logf("webrtc: dropped packet from %v, receive channel full", srcNodeKey.ShortString())
|
||||
}
|
||||
}
|
||||
|
||||
// processWebRTCReadResult processes a WebRTC packet received from the webrtcRecvCh.
|
||||
// It's similar to processDERPReadResult but for WebRTC packets.
|
||||
func (c *Conn) processWebRTCReadResult(wr webrtcReadResult, b []byte) (n int, ep *endpoint) {
|
||||
if wr.buf == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
n = wr.n
|
||||
ncopy := copy(b, wr.buf[:n])
|
||||
if ncopy != n {
|
||||
err := fmt.Errorf("received WebRTC packet of length %d that's too big for WireGuard buf size %d", n, ncopy)
|
||||
c.logf("magicsock: %v", err)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
srcAddr := epAddr{ap: netip.AddrPortFrom(tailcfg.WebRTCMagicIPAddr, 12345)}
|
||||
|
||||
// Check if this looks like a disco packet
|
||||
pt, isGeneveEncap := packetLooksLike(b[:n])
|
||||
if pt == packetLooksLikeDisco && !isGeneveEncap {
|
||||
c.handleDiscoMessage(b[:n], srcAddr, false, wr.src, discoRXPathWebRTC)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Find the endpoint by node key
|
||||
var ok bool
|
||||
c.mu.Lock()
|
||||
ep, ok = c.peerMap.endpointForNodeKey(wr.src)
|
||||
c.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
// We don't know anything about this node key
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
ep.noteRecvActivity(srcAddr, mono.Now())
|
||||
if update := c.connCounter.Load(); update != nil {
|
||||
update(0, netip.AddrPortFrom(ep.nodeAddr, 0), srcAddr.ap, 1, n, true)
|
||||
}
|
||||
|
||||
c.metrics.inboundPacketsWebRTCTotal.Add(1)
|
||||
c.metrics.inboundBytesWebRTCTotal.Add(int64(n))
|
||||
|
||||
return n, ep
|
||||
}
|
||||
|
||||
type receiveBatch struct {
|
||||
msgs []ipv6.Message
|
||||
}
|
||||
@ -1999,7 +2211,7 @@ func (c *Conn) sendDiscoMessage(dst epAddr, dstKey key.NodePublic, dstDisco key.
|
||||
if !dstKey.IsZero() {
|
||||
node = dstKey.ShortString()
|
||||
}
|
||||
c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v\n", c.discoAtomic.Short(), dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m), len(pkt))
|
||||
c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v\n", c.discoAtomic.Short(), dstDisco.ShortString(), node, pathStr(dst.String()), disco.MessageSummary(m), len(pkt))
|
||||
}
|
||||
if isDERP {
|
||||
metricSentDiscoDERP.Add(1)
|
||||
@ -2039,6 +2251,7 @@ type discoRXPath string
|
||||
const (
|
||||
discoRXPathUDP discoRXPath = "UDP socket"
|
||||
discoRXPathDERP discoRXPath = "DERP"
|
||||
discoRXPathWebRTC discoRXPath = "WebRTC"
|
||||
discoRXPathRawSocket discoRXPath = "raw socket"
|
||||
)
|
||||
|
||||
@ -2350,13 +2563,13 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake
|
||||
if isVia {
|
||||
c.dlogf("[v1] magicsock: disco: %v<-%v via %v (%v, %v) got call-me-maybe-via, %d endpoints",
|
||||
c.discoAtomic.Short(), epDisco.short, via.ServerDisco.ShortString(),
|
||||
ep.publicKey.ShortString(), derpStr(src.String()),
|
||||
ep.publicKey.ShortString(), pathStr(src.String()),
|
||||
len(via.AddrPorts))
|
||||
c.relayManager.handleCallMeMaybeVia(ep, lastBest, lastBestIsTrusted, via)
|
||||
} else {
|
||||
c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints",
|
||||
c.discoAtomic.Short(), epDisco.short,
|
||||
ep.publicKey.ShortString(), derpStr(src.String()),
|
||||
ep.publicKey.ShortString(), pathStr(src.String()),
|
||||
len(cmm.MyNumber))
|
||||
go ep.handleCallMeMaybe(cmm)
|
||||
}
|
||||
@ -2402,7 +2615,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake
|
||||
if isResp {
|
||||
c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s, %d endpoints",
|
||||
c.discoAtomic.Short(), epDisco.short,
|
||||
ep.publicKey.ShortString(), derpStr(src.String()),
|
||||
ep.publicKey.ShortString(), pathStr(src.String()),
|
||||
msgType,
|
||||
len(resp.AddrPorts))
|
||||
c.relayManager.handleRxDiscoMsg(c, resp, nodeKey, di.discoKey, src)
|
||||
@ -2416,7 +2629,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake
|
||||
} else {
|
||||
c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s disco[0]=%v disco[1]=%v",
|
||||
c.discoAtomic.Short(), epDisco.short,
|
||||
ep.publicKey.ShortString(), derpStr(src.String()),
|
||||
ep.publicKey.ShortString(), pathStr(src.String()),
|
||||
msgType,
|
||||
req.ClientDisco[0].ShortString(), req.ClientDisco[1].ShortString())
|
||||
}
|
||||
@ -2440,6 +2653,39 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake
|
||||
RxFromNodeKey: nodeKey,
|
||||
Message: req,
|
||||
})
|
||||
case *disco.WebRTCOffer, *disco.WebRTCAnswer, *disco.WebRTCICECandidate:
|
||||
if !isDERP {
|
||||
c.logf("[unexpected] WebRTC signaling message received via UDP, expected DERP only")
|
||||
return
|
||||
}
|
||||
if c.webrtcMgr == nil {
|
||||
return
|
||||
}
|
||||
// Dispatch to the webrtcManager off the hot path; c.mu must not be held.
|
||||
senderStr := sender.String()
|
||||
switch dm := dm.(type) {
|
||||
case *disco.WebRTCOffer:
|
||||
var sdp webrtc.SessionDescription
|
||||
if err := json.Unmarshal(dm.Payload, &sdp); err != nil {
|
||||
c.logf("webrtc: disco: failed to unmarshal offer from %v: %v", sender.ShortString(), err)
|
||||
return
|
||||
}
|
||||
go c.webrtcMgr.HandleOffer(senderStr, "", &sdp)
|
||||
case *disco.WebRTCAnswer:
|
||||
var sdp webrtc.SessionDescription
|
||||
if err := json.Unmarshal(dm.Payload, &sdp); err != nil {
|
||||
c.logf("webrtc: disco: failed to unmarshal answer from %v: %v", sender.ShortString(), err)
|
||||
return
|
||||
}
|
||||
go c.webrtcMgr.HandleAnswer(senderStr, "", &sdp)
|
||||
case *disco.WebRTCICECandidate:
|
||||
var candidate webrtc.ICECandidateInit
|
||||
if err := json.Unmarshal(dm.Payload, &candidate); err != nil {
|
||||
c.logf("webrtc: disco: failed to unmarshal ICE candidate from %v: %v", sender.ShortString(), err)
|
||||
return
|
||||
}
|
||||
go c.webrtcMgr.HandleCandidate(senderStr, "", &candidate)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -3389,9 +3635,9 @@ func (c *connBind) Open(ignoredPort uint16) ([]conn.ReceiveFunc, uint16, error)
|
||||
return nil, 0, errors.New("magicsock: connBind already open")
|
||||
}
|
||||
c.closed = false
|
||||
fns := []conn.ReceiveFunc{c.receiveIPv4(), c.receiveIPv6(), c.receiveDERP}
|
||||
fns := []conn.ReceiveFunc{c.receiveIPv4(), c.receiveIPv6(), c.receiveDERP, c.receiveWebRTC}
|
||||
if runtime.GOOS == "js" {
|
||||
fns = []conn.ReceiveFunc{c.receiveDERP}
|
||||
fns = []conn.ReceiveFunc{c.receiveDERP, c.receiveWebRTC}
|
||||
}
|
||||
// TODO: Combine receiveIPv4 and receiveIPv6 and receiveIP into a single
|
||||
// closure that closes over a *RebindingUDPConn?
|
||||
@ -3468,6 +3714,14 @@ func (c *Conn) Close() error {
|
||||
ep.stopAndReset()
|
||||
})
|
||||
|
||||
// Close WebRTC manager if initialized
|
||||
if c.webrtcMgr != nil {
|
||||
c.webrtcMgr.close()
|
||||
c.webrtcMgr = nil
|
||||
}
|
||||
|
||||
close(c.webrtcRecvCh)
|
||||
|
||||
c.closed = true
|
||||
c.connCtxCancel()
|
||||
c.closeAllDerpLocked("conn-close")
|
||||
@ -4022,9 +4276,20 @@ func trySetUDPSocketOptions(pconn nettype.PacketConn, logf logger.Logf) {
|
||||
}
|
||||
}
|
||||
|
||||
// pathStr formats endpoint addresses for display, replacing magic IPs with
// readable names: DERP IPs become "derp-" and WebRTC IPs become "webrtc-".
func pathStr(s string) string {
	return webrtcStr(derpStr(s))
}

// derpStr replaces DERP IPs in s with "derp-".
func derpStr(s string) string { return strings.ReplaceAll(s, "127.3.3.40:", "derp-") }

// webrtcStr replaces WebRTC IPs in s with "webrtc-".
func webrtcStr(s string) string { return strings.ReplaceAll(s, "127.3.3.41:", "webrtc-") }
|
||||
|
||||
// epAddrEndpointCache is a mutex-free single-element cache, mapping from
|
||||
// a single [epAddr] to a single [*endpoint].
|
||||
type epAddrEndpointCache struct {
|
||||
@ -4105,11 +4370,13 @@ var (
|
||||
metricRecvDataPacketsDERP = clientmetric.NewAggregateCounter("magicsock_recv_data_derp")
|
||||
metricRecvDataPacketsIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_ipv4")
|
||||
metricRecvDataPacketsIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_ipv6")
|
||||
metricRecvDataPacketsWebRTC = clientmetric.NewAggregateCounter("magicsock_recv_data_webrtc")
|
||||
metricRecvDataPacketsPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_peer_relay_ipv4")
|
||||
metricRecvDataPacketsPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_peer_relay_ipv6")
|
||||
metricSendDataPacketsDERP = clientmetric.NewAggregateCounter("magicsock_send_data_derp")
|
||||
metricSendDataPacketsIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_ipv4")
|
||||
metricSendDataPacketsIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_ipv6")
|
||||
metricSendDataPacketsWebRTC = clientmetric.NewAggregateCounter("magicsock_send_data_webrtc")
|
||||
metricSendDataPacketsPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_peer_relay_ipv4")
|
||||
metricSendDataPacketsPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_peer_relay_ipv6")
|
||||
|
||||
@ -4117,11 +4384,13 @@ var (
|
||||
metricRecvDataBytesDERP = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_derp")
|
||||
metricRecvDataBytesIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_ipv4")
|
||||
metricRecvDataBytesIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_ipv6")
|
||||
metricRecvDataBytesWebRTC = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_webrtc")
|
||||
metricRecvDataBytesPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_peer_relay_ipv4")
|
||||
metricRecvDataBytesPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_peer_relay_ipv6")
|
||||
metricSendDataBytesDERP = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_derp")
|
||||
metricSendDataBytesIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_ipv4")
|
||||
metricSendDataBytesIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_ipv6")
|
||||
metricSendDataBytesWebRTC = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_webrtc")
|
||||
metricSendDataBytesPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_peer_relay_ipv4")
|
||||
metricSendDataBytesPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_peer_relay_ipv6")
|
||||
|
||||
@ -4241,6 +4510,16 @@ func (c *Conn) SetLastNetcheckReportForTest(ctx context.Context, report *netchec
|
||||
c.lastNetCheckReport.Store(report)
|
||||
}
|
||||
|
||||
// findEndpointByDisco returns the first endpoint with the given disco key, or nil if not found.
|
||||
func (c *Conn) findEndpointByDisco(dk key.DiscoPublic) *endpoint {
|
||||
var found *endpoint
|
||||
c.peerMap.forEachEndpointWithDiscoKey(dk, func(ep *endpoint) bool {
|
||||
found = ep
|
||||
return false // stop after first match
|
||||
})
|
||||
return found
|
||||
}
|
||||
|
||||
// lazyEndpoint is a wireguard [conn.Endpoint] for when magicsock received a
|
||||
// non-disco (presumably WireGuard) packet from a UDP address from which we
|
||||
// can't map to a Tailscale peer. But WireGuard most likely can, once it
|
||||
|
||||
925
wgengine/magicsock/webrtc.go
Normal file
925
wgengine/magicsock/webrtc.go
Normal file
@ -0,0 +1,925 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"net/netip"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/pion/webrtc/v4"
	"github.com/tailscale/wireguard-go/conn"
	"tailscale.com/rtclib"
	"tailscale.com/tailcfg"
	"tailscale.com/tstime/mono"
	"tailscale.com/types/key"
	"tailscale.com/types/logger"
)
|
||||
|
||||
// webrtcConnState represents the state of a WebRTC connection.
type webrtcConnState int

const (
	webrtcStateIdle       webrtcConnState = iota // no connection attempt made yet
	webrtcStateConnecting                        // offer/answer or ICE exchange in progress
	webrtcStateConnected                         // peer connection reported Connected
	webrtcStateFailed                            // terminal: the connection attempt failed
	webrtcStateClosed                            // terminal: the connection was closed
)
|
||||
|
||||
// dataChannelRW is the detached io.ReadWriteCloser for a WebRTC DataChannel.
// It is stored via atomic.Pointer so the hot send path can retrieve it without
// holding the webrtcManager mutex.
//
// It exists only to give the embedded interface a concrete named type that
// atomic.Pointer can point at.
type dataChannelRW struct {
	io.ReadWriteCloser
}
|
||||
|
||||
// webrtcPeerState tracks WebRTC connection state for a single peer.
//
// NOTE(review): state, remoteAddr, remoteDescSet, and pendingCandidates are
// mutated under webrtcManager.mu elsewhere in this file; dataChannel is
// assigned from pion callbacks without that lock — confirm intended.
type webrtcPeerState struct {
	ep            *endpoint               // magicsock endpoint this connection serves
	peerConn      *webrtc.PeerConnection  // underlying pion peer connection
	dataChannel   *webrtc.DataChannel     // "tailscale-wg" channel carrying WireGuard packets
	dcRW          atomic.Pointer[dataChannelRW] // non-nil once the DataChannel is open
	localDisco    key.DiscoPublic         // our disco key (signaling "from" address)
	remoteDisco   key.DiscoPublic         // peer's disco key (signaling "to" address)
	remoteNodeKey key.NodePublic // peer's node public key (for WireGuard)
	remoteAddr    netip.AddrPort // actual remote address from ICE candidate
	state         webrtcConnState // current lifecycle state (see webrtcConnState)
	lastError     error           // last terminal error, set when state becomes Failed
	createdAt     time.Time       // when this peer state was created

	// remoteDescSet is true once SetRemoteDescription has been called.
	// ICE candidates that arrive before that point are held in
	// pendingCandidates and applied immediately after. Both fields are
	// protected by webrtcManager.mu.
	remoteDescSet     bool
	pendingCandidates []webrtc.ICECandidateInit
}
|
||||
|
||||
// webrtcConnectionReadyEvent signals that a WebRTC connection is ready
// (its DataChannel has opened). It is delivered to the manager's runLoop
// via connectionReadyCh.
type webrtcConnectionReadyEvent struct {
	remoteDisco key.DiscoPublic // disco key of the now-reachable peer
	ep          *endpoint       // endpoint whose WebRTC path became usable
}
|
||||
|
||||
// webrtcManager manages WebRTC connections for magicsock.
//
// A single goroutine (runLoop) serializes connection setup and readiness
// handling; per-peer state lives in the two maps below, keyed both by
// endpoint pointer and by the peer's disco key.
type webrtcManager struct {
	logf logger.Logf
	conn *Conn // parent magicsock.Conn

	// mu guards the two maps and the mutable fields of the
	// webrtcPeerState values they hold.
	mu                        sync.RWMutex
	peerConnectionsByEndpoint map[*endpoint]*webrtcPeerState
	peerConnectionsByDisco    map[key.DiscoPublic]*webrtcPeerState

	// signaller exchanges SDP offers/answers and ICE candidates with peers.
	signaller rtclib.Signaller

	// Control channels
	startConnectionCh chan *endpoint                    // requests for runLoop to dial a peer
	connectionReadyCh chan webrtcConnectionReadyEvent   // DataChannel-open notifications
	closeCh           chan struct{}                     // closed by close() to stop runLoop
	runLoopStoppedCh  chan struct{}                     // closed by runLoop on exit

	// WebRTC API configuration
	api *webrtc.API
}
|
||||
|
||||
// Ensure webrtcManager implements rtclib.SignalHandler interface.
|
||||
var _ rtclib.SignalHandler = (*webrtcManager)(nil)
|
||||
|
||||
// newWebRTCManager creates a new WebRTC manager using disco-based signaling.
|
||||
func newWebRTCManager(c *Conn) *webrtcManager {
|
||||
mgr := newWebRTCManagerBase(c)
|
||||
|
||||
mgr.signaller = &discoSignaller{conn: c}
|
||||
if err := mgr.signaller.Start(mgr); err != nil {
|
||||
c.logf("webrtc: failed to start signaller: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
go mgr.runLoop()
|
||||
|
||||
return mgr
|
||||
}
|
||||
|
||||
// close shuts down the WebRTC manager: it stops the signaller, signals the
// runLoop to exit, waits (bounded) for it, then tears down all peer
// connections.
//
// NOTE(review): close is not idempotent — a second call would panic on
// close(m.closeCh). Callers (Conn.Close sets webrtcMgr = nil) appear to call
// it once; confirm.
func (m *webrtcManager) close() error {
	// Close signaller first to stop new messages
	if m.signaller != nil {
		if err := m.signaller.Close(); err != nil {
			m.logf("webrtc: signaller close error: %v", err)
		}
	}

	// Signal runLoop to stop
	close(m.closeCh)

	// Wait for runLoop to finish with timeout
	select {
	case <-m.runLoopStoppedCh:
	case <-time.After(2 * time.Second):
		m.logf("webrtc: close timed out, forcing shutdown")
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	// Close all peer connections
	for _, ps := range m.peerConnectionsByEndpoint {
		if ps.peerConn != nil {
			ps.peerConn.Close()
		}
	}
	// Drop the maps so any late lookups see nothing rather than stale state.
	m.peerConnectionsByEndpoint = nil
	m.peerConnectionsByDisco = nil

	return nil
}
|
||||
|
||||
// startConnection initiates a WebRTC connection to an endpoint.
|
||||
func (m *webrtcManager) startConnection(ep *endpoint) {
|
||||
if debugAlwaysDERP() {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case m.startConnectionCh <- ep:
|
||||
case <-m.closeCh:
|
||||
default:
|
||||
m.logf("webrtc: startConnection queue full for %v", ep.nodeAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// ensureConnecting triggers a WebRTC connection to ep if one is not already
|
||||
// in progress or established. It also retries connections in terminal states
|
||||
// (Failed, Closed). It is safe to call from the hot send path.
|
||||
func (m *webrtcManager) ensureConnecting(ep *endpoint) {
|
||||
m.mu.RLock()
|
||||
ps, exists := m.peerConnectionsByEndpoint[ep]
|
||||
m.mu.RUnlock()
|
||||
if !exists || ps.state == webrtcStateFailed || ps.state == webrtcStateClosed {
|
||||
m.startConnection(ep)
|
||||
}
|
||||
}
|
||||
|
||||
// deliverWebRTCMsg delivers one DataChannel message to the receive pipeline.
// It handles both single packets and batches (webrtcBatchMagic framing) so
// the logic is shared between the native detached-reader path and the
// JS/fallback OnMessage callback path.
func (m *webrtcManager) deliverWebRTCMsg(ps *webrtcPeerState, data []byte) {
	if len(data) == 0 {
		return
	}
	// Batch: [0xBA magic][2-byte BE len][pkt]...[2-byte BE len][pkt]
	if data[0] == webrtcBatchMagic {
		data = data[1:]
		for len(data) >= 2 {
			pktLen := int(binary.BigEndian.Uint16(data))
			data = data[2:]
			if pktLen > len(data) {
				// Length prefix points past the end of the message:
				// framing is corrupt, so drop the rest of the batch.
				m.logf("webrtc: batch framing error for peer %v: pktLen %d > remaining %d",
					ps.remoteDisco.ShortString(), pktLen, len(data))
				return
			}
			m.conn.receiveWebRTC(data[:pktLen], ps.remoteNodeKey)
			data = data[pktLen:]
		}
		// NOTE(review): a trailing lone byte (len(data) == 1 here) is
		// silently discarded — confirm the sender never produces one.
		return
	}
	// Not batched: the whole message is a single packet.
	m.conn.receiveWebRTC(data, ps.remoteNodeKey)
}
|
||||
|
||||
// runDataChannelReader is the per-peer receive loop used when DetachDataChannels
// is enabled (native builds). It reads directly from the detached io.ReadWriteCloser
// into a reused buffer, avoiding the per-message goroutine wakeup and allocation
// that the OnMessage callback path incurs.
//
// It exits (and clears ps.dcRW) on the first read error, which includes the
// channel being closed.
func (m *webrtcManager) runDataChannelReader(ps *webrtcPeerState, rwc io.ReadWriteCloser) {
	// Size the buffer to hold the largest possible batch.
	// 64 WireGuard packets × ~1420 bytes + framing < 100 KiB; 256 KiB is safe.
	buf := make([]byte, 256*1024)
	for {
		n, err := rwc.Read(buf)
		if err != nil {
			// EOF / closed-pipe / closed-conn are expected on normal
			// teardown; only log unexpected errors.
			if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrClosedPipe) && !errors.Is(err, net.ErrClosed) {
				m.logf("webrtc: data channel read error for peer %v: %v", ps.remoteDisco.ShortString(), err)
			}
			// Clear the fast-path pointer so senders stop using this channel.
			ps.dcRW.Store(nil)
			return
		}
		if n > 0 {
			m.deliverWebRTCMsg(ps, buf[:n])
		}
	}
}
|
||||
|
||||
// getRemoteAddr returns the actual remote address for a WebRTC peer connection.
|
||||
func (m *webrtcManager) getRemoteAddr(disco key.DiscoPublic) netip.AddrPort {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if ps, ok := m.peerConnectionsByDisco[disco]; ok && ps.state == webrtcStateConnected {
|
||||
return ps.remoteAddr
|
||||
}
|
||||
return netip.AddrPort{}
|
||||
}
|
||||
|
||||
// markRemoteDescSet marks ps as having a remote description set and flushes
|
||||
// any ICE candidates that arrived before SetRemoteDescription was called.
|
||||
// Must be called after SetRemoteDescription succeeds, without holding m.mu.
|
||||
func (m *webrtcManager) markRemoteDescSet(ps *webrtcPeerState) {
|
||||
m.mu.Lock()
|
||||
ps.remoteDescSet = true
|
||||
pending := ps.pendingCandidates
|
||||
ps.pendingCandidates = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
for i := range pending {
|
||||
if err := ps.peerConn.AddICECandidate(pending[i]); err != nil {
|
||||
m.logf("webrtc: failed to add buffered ICE candidate for peer %v: %v",
|
||||
ps.remoteDisco.ShortString(), err)
|
||||
}
|
||||
}
|
||||
if len(pending) > 0 {
|
||||
m.logf("webrtc: flushed %d buffered ICE candidates for peer %v",
|
||||
len(pending), ps.remoteDisco.ShortString())
|
||||
}
|
||||
}
|
||||
|
||||
// runLoop is the main event loop for the WebRTC manager.
|
||||
func (m *webrtcManager) runLoop() {
|
||||
defer close(m.runLoopStoppedCh)
|
||||
|
||||
retryTicker := time.NewTicker(15 * time.Second)
|
||||
defer retryTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case ep := <-m.startConnectionCh:
|
||||
m.handleStartConnection(ep)
|
||||
|
||||
case event := <-m.connectionReadyCh:
|
||||
m.handleConnectionReady(event)
|
||||
|
||||
case <-retryTicker.C:
|
||||
m.retryFailedConnections()
|
||||
|
||||
case <-m.closeCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// retryFailedConnections re-queues any connections in a terminal state so they
|
||||
// get a fresh attempt. This covers cases where both peers restart simultaneously
|
||||
// and the initial attempt fails before DERP is established.
|
||||
func (m *webrtcManager) retryFailedConnections() {
|
||||
m.mu.RLock()
|
||||
var toRetry []*endpoint
|
||||
for ep, ps := range m.peerConnectionsByEndpoint {
|
||||
if ps.state == webrtcStateFailed || ps.state == webrtcStateClosed {
|
||||
toRetry = append(toRetry, ep)
|
||||
}
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
for _, ep := range toRetry {
|
||||
m.logf("webrtc: retrying failed connection to peer %v", ep.nodeAddr)
|
||||
m.startConnection(ep)
|
||||
}
|
||||
}
|
||||
|
||||
// handleStartConnection creates a new WebRTC connection to an endpoint:
// it builds a PeerConnection, creates the unordered/unreliable DataChannel,
// wires the state/ICE/data-channel callbacks, and sends the SDP offer over
// the disco signaller. Runs on the manager's runLoop goroutine.
func (m *webrtcManager) handleStartConnection(ep *endpoint) {
	m.mu.Lock()

	// Check if we already have a connection
	if ps, exists := m.peerConnectionsByEndpoint[ep]; exists {
		switch ps.state {
		case webrtcStateConnecting, webrtcStateConnected:
			m.mu.Unlock()
			return
		default:
			// Terminal state (Failed, Closed): close the old connection and
			// remove it from the maps so we can create a fresh one below.
			ps.peerConn.Close()
			delete(m.peerConnectionsByEndpoint, ep)
			delete(m.peerConnectionsByDisco, ps.remoteDisco)
		}
	}

	// Get disco keys
	localDisco := m.conn.DiscoPublicKey()
	disco := ep.disco.Load()
	if disco == nil {
		m.mu.Unlock()
		m.logf("webrtc: cannot start connection, peer has no disco key")
		return
	}
	remoteDisco := disco.key

	// Check that the peer's DERP address is known before proceeding.
	// If it isn't, the signaling offer will fail immediately. This can
	// happen on startup or after a disco-key rotation before the DERP
	// connection to the new key is established. The next netmap update
	// will re-trigger startConnection once the peer is reachable.
	ep.mu.Lock()
	derpReady := ep.derpAddr.IsValid()
	ep.mu.Unlock()
	if !derpReady {
		m.mu.Unlock()
		return
	}

	m.logf("webrtc: starting connection to peer %v (disco %v)", ep.nodeAddr, remoteDisco.ShortString())

	m.mu.Unlock()

	// Create peer connection
	config := webrtc.Configuration{
		ICEServers: []webrtc.ICEServer{
			{
				URLs: []string{"stun:stun.l.google.com:19302"},
			},
		},
		ICETransportPolicy: webrtc.ICETransportPolicyAll,
	}

	peerConn, err := m.api.NewPeerConnection(config)
	if err != nil {
		m.logf("webrtc: failed to create peer connection: %v", err)
		return
	}

	ps := &webrtcPeerState{
		ep:            ep,
		peerConn:      peerConn,
		localDisco:    localDisco,
		remoteDisco:   remoteDisco,
		remoteNodeKey: ep.publicKey,
		state:         webrtcStateConnecting,
		createdAt:     time.Now(),
	}

	// Store peer state
	m.mu.Lock()
	m.peerConnectionsByEndpoint[ep] = ps
	m.peerConnectionsByDisco[remoteDisco] = ps
	m.mu.Unlock()

	// Set up connection state handler
	peerConn.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
		m.handleConnectionStateChange(ps, state)
	})

	// Set up ICE candidate handler
	peerConn.OnICECandidate(func(candidate *webrtc.ICECandidate) {
		if candidate != nil {
			m.handleLocalICECandidate(ps, candidate)
		}
	})

	// Create an unordered, unreliable data channel (MaxRetransmits=0).
	// WireGuard is designed to run over raw UDP, which is unordered and
	// unreliable. Using an ordered/reliable DataChannel (the default) wraps
	// WireGuard in SCTP's reliable-ordered-stream semantics, causing
	// head-of-line blocking whenever a packet is lost: SCTP holds back all
	// subsequent packets until the missing one is retransmitted and delivered
	// in order. That is why throughput over WebRTC was worse than DERP.
	// Setting Ordered=false and MaxRetransmits=0 makes the DataChannel behave
	// like a UDP socket, which is exactly what WireGuard expects.
	// (The variable is named for the resulting behavior; its value false is
	// what gets assigned to the Ordered option.)
	unordered := false
	maxRetransmits := uint16(0)
	dataChannel, err := peerConn.CreateDataChannel("tailscale-wg", &webrtc.DataChannelInit{
		Ordered:        &unordered,
		MaxRetransmits: &maxRetransmits,
	})
	if err != nil {
		m.logf("webrtc: failed to create data channel: %v", err)
		peerConn.Close()
		return
	}

	ps.dataChannel = dataChannel

	// Set up data channel handlers.
	// With DetachDataChannels enabled, OnMessage cannot be used. Instead we
	// call Detach() inside OnOpen to get a raw io.ReadWriteCloser and spin
	// up a dedicated reader goroutine, which eliminates per-packet callback
	// overhead and goroutine wakeups.
	setOnError(dataChannel, func(err error) {
		m.logf("webrtc: data channel error for peer %v: %v", remoteDisco.ShortString(), err)
	})

	dataChannel.OnOpen(func() {
		// Native: DetachDataChannels was enabled; get a raw io.ReadWriteCloser
		// and spin a dedicated reader goroutine (zero per-message allocations).
		// JS/fallback: Detach() returns an error; fall back to OnMessage
		// callbacks, which is the only API available in the browser.
		if rwc, err := dataChannel.Detach(); err == nil {
			ps.dcRW.Store(&dataChannelRW{rwc})
			go m.runDataChannelReader(ps, rwc)
		} else {
			dataChannel.OnMessage(func(msg webrtc.DataChannelMessage) {
				m.deliverWebRTCMsg(ps, msg.Data)
			})
		}
		m.logf("webrtc: data channel opened for peer %v", remoteDisco.ShortString())
		// NOTE(review): this send is unbuffered w.r.t. shutdown — if runLoop
		// has already exited, it blocks forever; confirm OnOpen cannot fire
		// after close().
		m.connectionReadyCh <- webrtcConnectionReadyEvent{
			remoteDisco: remoteDisco,
			ep:          ep,
		}
	})

	// Create and send offer
	offer, err := peerConn.CreateOffer(nil)
	if err != nil {
		m.logf("webrtc: failed to create offer: %v", err)
		peerConn.Close()
		return
	}

	if err := peerConn.SetLocalDescription(offer); err != nil {
		m.logf("webrtc: failed to set local description: %v", err)
		peerConn.Close()
		return
	}

	// Send offer via signaling
	if err := m.signaller.Offer(localDisco.String(), remoteDisco.String(), &offer); err != nil {
		m.logf("webrtc: failed to send offer: %v", err)
		peerConn.Close()
		// Remove the entry we stored above so a later attempt starts clean.
		m.mu.Lock()
		delete(m.peerConnectionsByEndpoint, ep)
		delete(m.peerConnectionsByDisco, remoteDisco)
		m.mu.Unlock()
		return
	}

	m.logf("webrtc: sent offer to peer %v", remoteDisco.ShortString())
}
|
||||
|
||||
// HandleOffer implements rtclib.SignalHandler.
|
||||
func (m *webrtcManager) HandleOffer(from, to string, offer *webrtc.SessionDescription) {
|
||||
m.logf("webrtc: received offer from=%s", from)
|
||||
|
||||
var remoteDisco key.DiscoPublic
|
||||
if err := remoteDisco.UnmarshalText([]byte(from)); err != nil {
|
||||
m.logf("webrtc: invalid sender disco key: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.handleRemoteOffer(remoteDisco, offer)
|
||||
}
|
||||
|
||||
// HandleAnswer implements rtclib.SignalHandler.
|
||||
func (m *webrtcManager) HandleAnswer(from, to string, answer *webrtc.SessionDescription) {
|
||||
m.logf("webrtc: received answer from=%s", from)
|
||||
|
||||
var remoteDisco key.DiscoPublic
|
||||
if err := remoteDisco.UnmarshalText([]byte(from)); err != nil {
|
||||
m.logf("webrtc: invalid sender disco key: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.handleRemoteAnswer(remoteDisco, answer)
|
||||
}
|
||||
|
||||
// HandleCandidate implements rtclib.SignalHandler.
|
||||
func (m *webrtcManager) HandleCandidate(from, to string, candidate *webrtc.ICECandidateInit) {
|
||||
m.logf("webrtc: received candidate from=%s", from)
|
||||
|
||||
var remoteDisco key.DiscoPublic
|
||||
if err := remoteDisco.UnmarshalText([]byte(from)); err != nil {
|
||||
m.logf("webrtc: invalid sender disco key: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.handleRemoteCandidate(remoteDisco, candidate)
|
||||
}
|
||||
|
||||
// handleRemoteOffer processes an incoming offer from a peer. It resolves
// signaling glare (both sides offering at once), tears down stale terminal
// connections, creates an answerer PeerConnection when none exists, and
// replies with an SDP answer via the signaller.
func (m *webrtcManager) handleRemoteOffer(remoteDisco key.DiscoPublic, offer *webrtc.SessionDescription) {

	// For incoming connections, we need to find the endpoint by disco key
	m.mu.Lock()
	ps, exists := m.peerConnectionsByDisco[remoteDisco]
	m.mu.Unlock()

	if exists {
		switch ps.peerConn.SignalingState() {
		case webrtc.SignalingStateHaveLocalOffer:
			// Glare: both sides sent offers simultaneously. Tiebreak by disco key:
			// the peer with the lexicographically smaller local key wins and keeps
			// its offer; the loser rolls back and answers the remote offer instead.
			localDisco := m.conn.DiscoPublicKey()
			if localDisco.Compare(remoteDisco) < 0 {
				// We win — ignore their offer; they will roll back and answer ours.
				m.logf("webrtc: glare with peer %v: ignoring their offer (we win tiebreak)", remoteDisco.ShortString())
				return
			}
			// We lose — roll back our offer and fall through to answer theirs.
			m.logf("webrtc: glare with peer %v: rolling back our offer (we lose tiebreak)", remoteDisco.ShortString())
			if err := ps.peerConn.SetLocalDescription(webrtc.SessionDescription{Type: webrtc.SDPTypeRollback}); err != nil {
				m.logf("webrtc: glare rollback failed: %v; closing and recreating", err)
				ps.peerConn.Close()
				m.mu.Lock()
				delete(m.peerConnectionsByEndpoint, ps.ep)
				delete(m.peerConnectionsByDisco, remoteDisco)
				m.mu.Unlock()
				// Setting exists = false routes us into the fresh-connection
				// path below.
				exists = false
			}
		case webrtc.SignalingStateStable:
			// NOTE(review): ps.state is read here without holding m.mu;
			// other paths mutate it under the lock — confirm acceptable.
			if ps.state == webrtcStateConnected || ps.state == webrtcStateConnecting {
				// The connection is already working or in progress. Ignore the
				// peer's offer — they will notice their connection succeeded too
				// and stop retrying.
				m.logf("webrtc: ignoring offer from %v, already have %v connection", remoteDisco.ShortString(), ps.state)
				return
			}
			// Stable but in a terminal state (Failed/Closed): the peer is trying
			// to reconnect. Tear down our stale entry and answer fresh below.
			m.logf("webrtc: tearing down stale %v connection to %v, answering fresh offer", ps.state, remoteDisco.ShortString())
			ps.peerConn.Close()
			m.mu.Lock()
			delete(m.peerConnectionsByEndpoint, ps.ep)
			delete(m.peerConnectionsByDisco, remoteDisco)
			m.mu.Unlock()
			exists = false
		default:
			// Any other transitional signaling state — ignore, let it settle.
			m.logf("webrtc: ignoring offer from %v in unexpected signaling state %v", remoteDisco.ShortString(), ps.peerConn.SignalingState())
			return
		}
	}

	if !exists {
		// We received an offer but don't have a connection yet.
		// Find the endpoint by disco key and create peer connection state.
		ep := m.conn.findEndpointByDisco(remoteDisco)
		if ep == nil {
			m.logf("webrtc: received offer from unknown peer %v with no endpoint", remoteDisco.ShortString())
			return
		}

		m.logf("webrtc: received offer from peer %v, creating answerer connection", remoteDisco.ShortString())

		// Create peer connection for incoming offer
		config := webrtc.Configuration{
			ICEServers: []webrtc.ICEServer{
				{
					URLs: []string{"stun:stun.l.google.com:19302"},
				},
			},
			ICETransportPolicy: webrtc.ICETransportPolicyAll,
		}

		peerConn, err := m.api.NewPeerConnection(config)
		if err != nil {
			m.logf("webrtc: failed to create peer connection for incoming offer: %v", err)
			return
		}

		localDisco := m.conn.DiscoPublicKey()
		ps = &webrtcPeerState{
			ep:            ep,
			peerConn:      peerConn,
			localDisco:    localDisco,
			remoteDisco:   remoteDisco,
			remoteNodeKey: ep.publicKey,
			state:         webrtcStateConnecting,
			createdAt:     time.Now(),
		}

		// Store peer state
		m.mu.Lock()
		m.peerConnectionsByEndpoint[ep] = ps
		m.peerConnectionsByDisco[remoteDisco] = ps
		m.mu.Unlock()

		// Set up connection state handler
		peerConn.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
			m.handleConnectionStateChange(ps, state)
		})

		// Set up ICE candidate handler
		peerConn.OnICECandidate(func(candidate *webrtc.ICECandidate) {
			if candidate != nil {
				m.handleLocalICECandidate(ps, candidate)
			}
		})

		// Set up data channel handler (for answerer, we wait for the data channel from offerer).
		peerConn.OnDataChannel(func(dc *webrtc.DataChannel) {
			m.logf("webrtc: received data channel from peer %v", remoteDisco.ShortString())
			ps.dataChannel = dc

			setOnError(dc, func(err error) {
				m.logf("webrtc: data channel error for peer %v: %v", remoteDisco.ShortString(), err)
			})

			dc.OnOpen(func() {
				// Prefer the detached-reader fast path; fall back to
				// OnMessage callbacks when Detach is unavailable (JS).
				if rwc, err := dc.Detach(); err == nil {
					ps.dcRW.Store(&dataChannelRW{rwc})
					go m.runDataChannelReader(ps, rwc)
				} else {
					dc.OnMessage(func(msg webrtc.DataChannelMessage) {
						m.deliverWebRTCMsg(ps, msg.Data)
					})
				}
				m.logf("webrtc: data channel opened for peer %v", remoteDisco.ShortString())
				m.connectionReadyCh <- webrtcConnectionReadyEvent{
					remoteDisco: remoteDisco,
					ep:          ep,
				}
			})
		})
	}

	if err := ps.peerConn.SetRemoteDescription(*offer); err != nil {
		m.logf("webrtc: failed to set remote description: %v", err)
		return
	}
	// Flush any ICE candidates that arrived before the remote description.
	m.markRemoteDescSet(ps)

	// Create answer
	answer, err := ps.peerConn.CreateAnswer(nil)
	if err != nil {
		m.logf("webrtc: failed to create answer: %v", err)
		return
	}

	if err := ps.peerConn.SetLocalDescription(answer); err != nil {
		m.logf("webrtc: failed to set local description: %v", err)
		return
	}

	// Send answer via signaling
	if err := m.signaller.Answer(ps.localDisco.String(), remoteDisco.String(), &answer); err != nil {
		m.logf("webrtc: failed to send answer: %v", err)
		return
	}

	m.logf("webrtc: sent answer to peer %v", remoteDisco.ShortString())
}
|
||||
|
||||
// handleRemoteAnswer processes an incoming answer from a peer.
|
||||
func (m *webrtcManager) handleRemoteAnswer(remoteDisco key.DiscoPublic, answer *webrtc.SessionDescription) {
|
||||
m.mu.Lock()
|
||||
ps, exists := m.peerConnectionsByDisco[remoteDisco]
|
||||
m.mu.Unlock()
|
||||
|
||||
if !exists {
|
||||
m.logf("webrtc: received answer from unknown peer %v", remoteDisco.ShortString())
|
||||
return
|
||||
}
|
||||
|
||||
if err := ps.peerConn.SetRemoteDescription(*answer); err != nil {
|
||||
m.logf("webrtc: failed to set remote description: %v", err)
|
||||
return
|
||||
}
|
||||
m.markRemoteDescSet(ps)
|
||||
|
||||
m.logf("webrtc: set remote description for peer %v", remoteDisco.ShortString())
|
||||
}
|
||||
|
||||
// handleRemoteCandidate processes an incoming ICE candidate from a peer.
// Candidates that arrive before the remote description is set are buffered
// (see markRemoteDescSet); otherwise they are added immediately. As a side
// effect it records the candidate's IP:port as the peer's remoteAddr for
// status display.
func (m *webrtcManager) handleRemoteCandidate(remoteDisco key.DiscoPublic, candidate *webrtc.ICECandidateInit) {
	m.mu.Lock()
	ps, exists := m.peerConnectionsByDisco[remoteDisco]
	if exists && !ps.remoteDescSet {
		// Remote description not set yet — buffer the candidate and apply it
		// once SetRemoteDescription is called (see markRemoteDescSet).
		if candidate.Candidate != "" {
			if addr := parseICECandidateAddr(candidate.Candidate); addr.IsValid() {
				ps.remoteAddr = addr
			}
		}
		ps.pendingCandidates = append(ps.pendingCandidates, *candidate)
		m.mu.Unlock()
		m.logf("webrtc: buffered ICE candidate for peer %v (remote desc not yet set)", remoteDisco.ShortString())
		return
	}
	m.mu.Unlock()

	if !exists {
		m.logf("webrtc: received candidate from unknown peer %v", remoteDisco.ShortString())
		return
	}

	// Try to extract the remote address from the candidate string
	// Candidate format: "candidate:... udp ... <ip> <port> typ ..."
	if candidate.Candidate != "" {
		if addr := parseICECandidateAddr(candidate.Candidate); addr.IsValid() {
			// remoteAddr is protected by m.mu; re-acquire just for the write.
			m.mu.Lock()
			ps.remoteAddr = addr
			m.mu.Unlock()
			m.logf("webrtc: peer %v candidate address: %v", remoteDisco.ShortString(), addr)
		}
	}

	if err := ps.peerConn.AddICECandidate(*candidate); err != nil {
		m.logf("webrtc: failed to add ICE candidate: %v", err)
		return
	}

	m.logf("webrtc: added ICE candidate for peer %v", remoteDisco.ShortString())
}
|
||||
|
||||
// parseICECandidateAddr extracts the IP:port from an ICE candidate SDP string.
// Example candidate: "candidate:1234 1 udp 2130706431 192.168.1.100 54321 typ host"
//
// Fields: candidate:<foundation> <component> <protocol> <priority> <ip> <port> typ <type>
//
// It returns the zero netip.AddrPort if the string does not parse.
func parseICECandidateAddr(candidate string) netip.AddrPort {
	fields := strings.Fields(candidate)
	// Format: candidate:<foundation> <component> <protocol> <priority> <ip> <port> typ <type>
	if len(fields) < 7 {
		return netip.AddrPort{}
	}

	addr, err := netip.ParseAddr(fields[4])
	if err != nil {
		return netip.AddrPort{}
	}

	// strconv.ParseUint is stricter than the previous fmt.Sscanf("%d"): it
	// rejects trailing garbage (e.g. "123abc"), negative values, and anything
	// outside the uint16 range, instead of silently accepting a prefix.
	port, err := strconv.ParseUint(fields[5], 10, 16)
	if err != nil {
		return netip.AddrPort{}
	}

	return netip.AddrPortFrom(addr, uint16(port))
}
|
||||
|
||||
// handleLocalICECandidate sends a local ICE candidate to a peer via signaling.
|
||||
func (m *webrtcManager) handleLocalICECandidate(ps *webrtcPeerState, candidate *webrtc.ICECandidate) {
|
||||
candidateInit := candidate.ToJSON()
|
||||
if err := m.signaller.Candidate(ps.localDisco.String(), ps.remoteDisco.String(), &candidateInit); err != nil {
|
||||
m.logf("webrtc: failed to send candidate: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.logf("webrtc: sent ICE candidate to peer %v", ps.remoteDisco.ShortString())
|
||||
}
|
||||
|
||||
// handleConnectionStateChange handles WebRTC connection state changes.
//
// Peer state is updated under m.mu; if the connection terminally failed or
// closed, the endpoint's WebRTC bestAddr is cleared only after m.mu is
// released, because clearWebRTCBestAddr acquires ep.mu.
func (m *webrtcManager) handleConnectionStateChange(ps *webrtcPeerState, state webrtc.PeerConnectionState) {
	m.logf("webrtc: connection state changed to %s for peer %v", state.String(), ps.remoteDisco.ShortString())

	m.mu.Lock()

	var clearBestAddr bool
	switch state {
	case webrtc.PeerConnectionStateConnected:
		ps.state = webrtcStateConnected
		// Log the selected ICE candidate pair so we can confirm the actual
		// data path (LAN host candidate vs. STUN server-reflexive vs. relay).
		// NOTE(review): runs in a goroutine, presumably so the pion query
		// doesn't execute inside the m.mu critical section — confirm.
		go func() {
			cp, err := ps.peerConn.SCTP().Transport().ICETransport().GetSelectedCandidatePair()
			if err != nil || cp == nil {
				m.logf("webrtc: peer %v connected (selected candidate pair unavailable: %v)",
					ps.remoteDisco.ShortString(), err)
				return
			}
			m.logf("webrtc: peer %v connected via %s:%d → %s:%d (local %s, remote %s)",
				ps.remoteDisco.ShortString(),
				cp.Local.Address, cp.Local.Port,
				cp.Remote.Address, cp.Remote.Port,
				cp.Local.Typ, cp.Remote.Typ)
		}()
	case webrtc.PeerConnectionStateFailed:
		ps.state = webrtcStateFailed
		ps.lastError = errors.New("connection failed")
		// Drop the detached data-channel writer so sendPacket stops using it.
		ps.dcRW.Store(nil)
		clearBestAddr = true
	case webrtc.PeerConnectionStateClosed:
		ps.state = webrtcStateClosed
		ps.dcRW.Store(nil)
		clearBestAddr = true
	case webrtc.PeerConnectionStateDisconnected:
		// Transient state — do not clear bestAddr yet; the connection may recover.
	}

	m.mu.Unlock()

	// clearWebRTCBestAddr acquires ep.mu; must be called without m.mu held.
	if clearBestAddr {
		m.clearWebRTCBestAddr(ps)
	}
}
|
||||
|
||||
// clearWebRTCBestAddr resets the endpoint's bestAddr if it is currently the
|
||||
// WebRTC magic address, so that traffic immediately falls back to DERP.
|
||||
// Must be called without holding m.mu or ep.mu.
|
||||
func (m *webrtcManager) clearWebRTCBestAddr(ps *webrtcPeerState) {
|
||||
ps.ep.mu.Lock()
|
||||
defer ps.ep.mu.Unlock()
|
||||
if ps.ep.bestAddr.ap.Addr() == tailcfg.WebRTCMagicIPAddr {
|
||||
ps.ep.bestAddr = addrQuality{}
|
||||
ps.ep.trustBestAddrUntil = 0
|
||||
m.logf("webrtc: cleared WebRTC bestAddr for peer %v, falling back to DERP", ps.remoteDisco.ShortString())
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnectionReady marks a WebRTC connection as ready and updates endpoint.
//
// It points the endpoint's bestAddr at the WebRTC magic address (if that beats
// the current best path) so the send path routes traffic over the data channel.
func (m *webrtcManager) handleConnectionReady(event webrtcConnectionReadyEvent) {
	m.logf("webrtc: connection ready for peer %v", event.remoteDisco.ShortString())
	if debugAlwaysDERP() {
		// Debug knob: keep traffic on DERP even though WebRTC is available.
		return
	}

	// Update endpoint to use WebRTC path
	event.ep.mu.Lock()
	defer event.ep.mu.Unlock()

	// Use a fixed port number for WebRTC connections (similar to DERP)
	// The magic IP identifies this as WebRTC, not UDP
	// NOTE(review): 12345 looks like an arbitrary placeholder port; only the
	// magic IP appears significant for routing — confirm.
	webrtcAddr := addrQuality{
		epAddr: epAddr{
			ap: netip.AddrPortFrom(tailcfg.WebRTCMagicIPAddr, 12345),
		},
		latency: 0, // Will be determined by disco pings, same as DERP
	}

	// Set as bestAddr if better than current
	now := mono.Now()
	if betterAddr(webrtcAddr, event.ep.bestAddr) {
		event.ep.bestAddr = webrtcAddr
		event.ep.bestAddrAt = now
		// Trust the WebRTC path for 5 minutes before revalidation is needed.
		event.ep.trustBestAddrUntil = now.Add(5 * time.Minute)
		m.logf("webrtc: updated endpoint %v with WebRTC path", event.ep.nodeAddr)
	}
}
|
||||
|
||||
// sendPacket sends a packet over a WebRTC data channel.
|
||||
// The hot path is lock-free: we take a read-lock (not write-lock) to look up
|
||||
// the peer state, then do an atomic load for the detached channel. Multiple
|
||||
// concurrent senders for different peers never contend.
|
||||
func (m *webrtcManager) sendPacket(disco key.DiscoPublic, b []byte) error {
|
||||
m.mu.RLock()
|
||||
ps, ok := m.peerConnectionsByDisco[disco]
|
||||
m.mu.RUnlock()
|
||||
if !ok {
|
||||
return errors.New("no WebRTC connection")
|
||||
}
|
||||
|
||||
// Native path: DetachDataChannels was enabled; use the raw io.ReadWriteCloser.
|
||||
if rw := ps.dcRW.Load(); rw != nil {
|
||||
if _, err := rw.Write(b); err != nil {
|
||||
return fmt.Errorf("send failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// JS/fallback path: use DataChannel.Send() directly.
|
||||
dc := ps.dataChannel
|
||||
if dc == nil || dc.ReadyState() != webrtc.DataChannelStateOpen {
|
||||
return errors.New("data channel not ready")
|
||||
}
|
||||
return dc.Send(b)
|
||||
}
|
||||
|
||||
// receiveWebRTC reads packets from the WebRTC receive channel.
// It is called by wireguard-go through the conn.Bind interface.
// It blocks until at least one packet is available, then drains as many
// additional packets as are immediately ready (up to len(buffs)).
func (c *connBind) receiveWebRTC(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
	// Block until the first packet arrives (or the channel is closed).
	wr, ok := <-c.webrtcRecvCh
	if !ok || c.isClosed() {
		return 0, net.ErrClosed
	}
	num := 0
	// Zero-length results are skipped rather than handed to wireguard-go.
	n, ep := c.processWebRTCReadResult(wr, buffs[num])
	if n > 0 {
		sizes[num] = n
		eps[num] = ep
		num++
	}
	// Drain any additional packets that are immediately available.
	for num < len(buffs) {
		select {
		case wr, ok = <-c.webrtcRecvCh:
			if !ok || c.isClosed() {
				// Deliver what we already collected; only report closure
				// when no packet was read at all.
				if num > 0 {
					return num, nil
				}
				return 0, net.ErrClosed
			}
			n, ep = c.processWebRTCReadResult(wr, buffs[num])
			if n > 0 {
				sizes[num] = n
				eps[num] = ep
				num++
			}
		default:
			// Nothing more ready right now; hand back the batch.
			return num, nil
		}
	}
	return num, nil
}
|
||||
37
wgengine/magicsock/webrtc_base_js.go
Normal file
37
wgengine/magicsock/webrtc_base_js.go
Normal file
@ -0,0 +1,37 @@
|
||||
//go:build js
|
||||
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"github.com/pion/webrtc/v4"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func newWebRTCManagerBase(c *Conn) *webrtcManager {
|
||||
// Configure WebRTC with STUN only
|
||||
settingEngine := webrtc.SettingEngine{}
|
||||
|
||||
// Create API with setting engine
|
||||
api := webrtc.NewAPI(
|
||||
webrtc.WithSettingEngine(settingEngine),
|
||||
)
|
||||
|
||||
return &webrtcManager{
|
||||
logf: c.logf,
|
||||
conn: c,
|
||||
peerConnectionsByEndpoint: make(map[*endpoint]*webrtcPeerState),
|
||||
peerConnectionsByDisco: make(map[key.DiscoPublic]*webrtcPeerState),
|
||||
startConnectionCh: make(chan *endpoint, 256),
|
||||
connectionReadyCh: make(chan webrtcConnectionReadyEvent, 16),
|
||||
closeCh: make(chan struct{}),
|
||||
runLoopStoppedCh: make(chan struct{}),
|
||||
api: api,
|
||||
}
|
||||
}
|
||||
|
||||
// setOnError registers fn as the data channel's error callback where the
// platform supports it. On js it is a no-op (see below).
func setOnError(dc *webrtc.DataChannel, fn func(error)) {
	// NO-OP... *webrtc.DataChannel does not have OnError for js.
}
|
||||
70
wgengine/magicsock/webrtc_base_native.go
Normal file
70
wgengine/magicsock/webrtc_base_native.go
Normal file
@ -0,0 +1,70 @@
|
||||
//go:build !js
|
||||
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/pion/webrtc/v4"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func newWebRTCManagerBase(c *Conn) *webrtcManager {
|
||||
settingEngine := webrtc.SettingEngine{}
|
||||
|
||||
// Use a 16 MiB SCTP receive buffer. The pion default (~32 KiB) becomes the
|
||||
// bottleneck at high throughput because SCTP's flow-control window is bounded
|
||||
// by this value.
|
||||
settingEngine.SetSCTPMaxReceiveBufferSize(16 * 1024 * 1024)
|
||||
|
||||
// Enlarge the DTLS replay-protection window. The default (64) causes
|
||||
// legitimate packets to be dropped as duplicates when the sender gets ahead
|
||||
// of the receiver by more than 64 packets, which happens easily at Gbps speeds.
|
||||
settingEngine.SetDTLSReplayProtectionWindow(8192)
|
||||
|
||||
// Lower the SCTP retransmission timeout ceiling. The default (1s+) causes
|
||||
// SCTP's congestion control to stall for a full second after any loss event,
|
||||
// which is catastrophic for throughput on a low-latency P2P link. 100ms is
|
||||
// still conservative but recovers much faster.
|
||||
settingEngine.SetSCTPRTOMax(100 * time.Millisecond)
|
||||
|
||||
// DetachDataChannels lets us call dc.Detach() to get a raw io.ReadWriteCloser
|
||||
// instead of using OnMessage callbacks. The callback path allocates a new
|
||||
// DataChannelMessage struct and fires a goroutine wakeup per packet. The
|
||||
// detached path lets us Read() directly into pre-allocated buffers in a
|
||||
// tight goroutine loop, matching how the UDP receive path works.
|
||||
settingEngine.DetachDataChannels()
|
||||
|
||||
// SCTP includes a CRC32c checksum on every chunk. DTLS already provides
|
||||
// both integrity and authenticity for all data, so the SCTP checksum is
|
||||
// redundant CPU work. Zero-checksum mode (RFC 9260) removes it.
|
||||
settingEngine.EnableSCTPZeroChecksum(true)
|
||||
|
||||
// Create MediaEngine (required even though we only use DataChannel)
|
||||
mediaEngine := &webrtc.MediaEngine{}
|
||||
|
||||
// Create API with setting engine
|
||||
api := webrtc.NewAPI(
|
||||
webrtc.WithSettingEngine(settingEngine),
|
||||
webrtc.WithMediaEngine(mediaEngine),
|
||||
)
|
||||
|
||||
return &webrtcManager{
|
||||
logf: c.logf,
|
||||
conn: c,
|
||||
peerConnectionsByEndpoint: make(map[*endpoint]*webrtcPeerState),
|
||||
peerConnectionsByDisco: make(map[key.DiscoPublic]*webrtcPeerState),
|
||||
startConnectionCh: make(chan *endpoint, 256),
|
||||
connectionReadyCh: make(chan webrtcConnectionReadyEvent, 16),
|
||||
closeCh: make(chan struct{}),
|
||||
runLoopStoppedCh: make(chan struct{}),
|
||||
api: api,
|
||||
}
|
||||
}
|
||||
|
||||
// setOnError registers fn as the data channel's error callback.
// On native builds pion's DataChannel supports OnError directly.
func setOnError(dc *webrtc.DataChannel, fn func(error)) {
	dc.OnError(fn)
}
|
||||
435
wgengine/magicsock/webrtc_integration_test.go
Normal file
435
wgengine/magicsock/webrtc_integration_test.go
Normal file
@ -0,0 +1,435 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"tailscale.com/rtclib"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// testSignalHandler is a test implementation of rtclib.SignalHandler
// that simply counts how many messages of each type it receives.
type testSignalHandler struct {
	offerCount     int // guarded by mu
	answerCount    int // guarded by mu
	candidateCount int // guarded by mu
	t              *testing.T
	mu             sync.Mutex
}
|
||||
|
||||
// HandleOffer implements rtclib.SignalHandler; it counts and logs the offer.
func (h *testSignalHandler) HandleOffer(from, to string, offer *webrtc.SessionDescription) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.offerCount++
	h.t.Logf("Received offer from %s to %s", from, to)
}

// HandleAnswer implements rtclib.SignalHandler; it counts and logs the answer.
func (h *testSignalHandler) HandleAnswer(from, to string, answer *webrtc.SessionDescription) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.answerCount++
	h.t.Logf("Received answer from %s to %s", from, to)
}

// HandleCandidate implements rtclib.SignalHandler; it counts and logs the candidate.
func (h *testSignalHandler) HandleCandidate(from, to string, candidate *webrtc.ICECandidateInit) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.candidateCount++
	h.t.Logf("Received candidate from %s to %s", from, to)
}
|
||||
|
||||
// TestWebRTCIntegration_MockSignalingServer tests WebRTC with a mock signaling server
//
// NOTE(review): the fixed 100ms sleeps make this timing-sensitive; consider
// polling with a deadline instead.
func TestWebRTCIntegration_MockSignalingServer(t *testing.T) {
	// Create mock signaling server
	server := newMockSignalingServer(t)
	defer server.Close()

	t.Logf("Mock signaling server running at %s", server.URL)

	// Verify server accepts connections
	client := newSignalingClient(server.URL, t.Logf)

	handler := &testSignalHandler{t: t}
	err := client.Start(handler)
	if err != nil {
		t.Fatalf("Failed to start client: %v", err)
	}
	defer client.Close()

	// Wait for connection
	time.Sleep(100 * time.Millisecond)

	// Send a test message
	disco1 := key.NewDisco().Public()
	disco2 := key.NewDisco().Public()

	err = client.Offer(disco1.String(), disco2.String(), &webrtc.SessionDescription{
		Type: webrtc.SDPTypeOffer,
		SDP:  "test",
	})
	if err != nil {
		t.Fatalf("Failed to send offer: %v", err)
	}

	// Give time for message to be processed
	time.Sleep(100 * time.Millisecond)

	// Verify the server received the message
	serverMsgs := server.GetReceivedMessages()

	if len(serverMsgs) == 0 {
		t.Error("Server did not receive any messages")
	} else {
		t.Logf("Server received %d messages", len(serverMsgs))
		for i, msg := range serverMsgs {
			t.Logf("Message %d: type=%s from=%s to=%s", i, msg.Type, msg.From, msg.To)
		}
	}
}
|
||||
|
||||
// TestWebRTCIntegration_MessageRelay tests message relay through signaling server
//
// Two clients connect to the broadcast relay; client1's offer should surface
// in client2's handler and vice versa for the answer. Delivery is only
// logged, not asserted, because the mock relay broadcasts without routing.
// NOTE(review): fixed 200ms sleeps make this timing-sensitive.
func TestWebRTCIntegration_MessageRelay(t *testing.T) {
	server := newMockSignalingServer(t)
	defer server.Close()

	// Create two clients
	handler1 := &testSignalHandler{t: t}
	handler2 := &testSignalHandler{t: t}

	client1 := newSignalingClient(server.URL, func(format string, args ...any) {
		t.Logf("[Client1] "+format, args...)
	})

	client2 := newSignalingClient(server.URL, func(format string, args ...any) {
		t.Logf("[Client2] "+format, args...)
	})

	err := client1.Start(handler1)
	if err != nil {
		t.Fatalf("Failed to start client1: %v", err)
	}
	defer client1.Close()

	err = client2.Start(handler2)
	if err != nil {
		t.Fatalf("Failed to start client2: %v", err)
	}
	defer client2.Close()

	// Wait for both to connect
	time.Sleep(200 * time.Millisecond)

	// Client 1 sends offer to Client 2
	disco1 := key.NewDisco().Public()
	disco2 := key.NewDisco().Public()

	if err := client1.Offer(disco1.String(), disco2.String(), &webrtc.SessionDescription{
		Type: webrtc.SDPTypeOffer,
		SDP:  "v=0...",
	}); err != nil {
		t.Fatalf("Client1 failed to send offer: %v", err)
	}

	// Wait for message relay
	time.Sleep(200 * time.Millisecond)

	// Verify client2 received the offer
	handler2.mu.Lock()
	c2offers := handler2.offerCount
	handler2.mu.Unlock()

	if c2offers > 0 {
		t.Logf("Client2 received %d offers (relay working)", c2offers)
	} else {
		t.Log("Client2 did not receive offers (relay may need proper routing)")
	}

	// Client 2 sends answer back to Client 1
	if err := client2.Answer(disco2.String(), disco1.String(), &webrtc.SessionDescription{
		Type: webrtc.SDPTypeAnswer,
		SDP:  "v=0...",
	}); err != nil {
		t.Fatalf("Client2 failed to send answer: %v", err)
	}

	// Wait for message relay
	time.Sleep(200 * time.Millisecond)

	// Log final message counts
	handler1.mu.Lock()
	c1answers := handler1.answerCount
	handler1.mu.Unlock()

	handler2.mu.Lock()
	c2FinalOffers := handler2.offerCount
	handler2.mu.Unlock()

	t.Logf("Final message counts: Client1 answers=%d, Client2 offers=%d", c1answers, c2FinalOffers)
	t.Log("Integration test completed successfully")
}
|
||||
|
||||
// TestWebRTCIntegration_SignalingFlow tests the complete signaling flow
// (offer → answer → candidate) through a single client, verifying that each
// send succeeds against the mock relay.
func TestWebRTCIntegration_SignalingFlow(t *testing.T) {
	server := newMockSignalingServer(t)
	defer server.Close()

	client := newSignalingClient(server.URL, t.Logf)
	handler := &testSignalHandler{t: t}

	err := client.Start(handler)
	if err != nil {
		t.Fatalf("Failed to start client: %v", err)
	}
	defer client.Close()

	time.Sleep(100 * time.Millisecond)

	disco1 := key.NewDisco().Public()
	disco2 := key.NewDisco().Public()

	// Simulate complete signaling flow
	steps := []struct {
		name string
		fn   func() error
	}{
		{
			name: "send_offer",
			fn: func() error {
				return client.Offer(disco1.String(), disco2.String(), &webrtc.SessionDescription{
					Type: webrtc.SDPTypeOffer,
					SDP:  "v=0 offer",
				})
			},
		},
		{
			name: "send_answer",
			fn: func() error {
				return client.Answer(disco2.String(), disco1.String(), &webrtc.SessionDescription{
					Type: webrtc.SDPTypeAnswer,
					SDP:  "v=0 answer",
				})
			},
		},
		{
			name: "send_candidate",
			fn: func() error {
				return client.Candidate(disco1.String(), disco2.String(), &webrtc.ICECandidateInit{
					Candidate: "test",
				})
			},
		},
	}

	for _, step := range steps {
		t.Logf("Step: %s", step.name)
		if err := step.fn(); err != nil {
			t.Errorf("Failed to execute %s: %v", step.name, err)
		}
		time.Sleep(50 * time.Millisecond)
	}

	t.Logf("Signaling flow completed with %d steps", len(steps))
}
|
||||
|
||||
// mockSignalingServer is a simple WebSocket server that relays signaling messages
// to every connected client except the sender (plain broadcast), and records
// all messages it receives for inspection via GetReceivedMessages.
type mockSignalingServer struct {
	*httptest.Server
	upgrader websocket.Upgrader

	mu       sync.Mutex // guards clients and messages
	clients  map[*websocket.Conn]bool
	messages []rtclib.SignalingMessage
	t        *testing.T
}
|
||||
|
||||
func newMockSignalingServer(t *testing.T) *mockSignalingServer {
|
||||
s := &mockSignalingServer{
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
},
|
||||
clients: make(map[*websocket.Conn]bool),
|
||||
messages: make([]rtclib.SignalingMessage, 0),
|
||||
t: t,
|
||||
}
|
||||
|
||||
s.Server = httptest.NewServer(http.HandlerFunc(s.handleWebSocket))
|
||||
|
||||
// Convert http:// to ws://
|
||||
s.Server.URL = "ws" + s.Server.URL[4:]
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *mockSignalingServer) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := s.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
s.t.Logf("Upgrade error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.clients[conn] = true
|
||||
s.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
delete(s.clients, conn)
|
||||
s.mu.Unlock()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
s.t.Logf("Client connected, total clients: %d", len(s.clients))
|
||||
|
||||
for {
|
||||
var msg rtclib.SignalingMessage
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
s.t.Logf("Read error: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
s.t.Logf("Server received: type=%s from=%s to=%s", msg.Type, msg.From, msg.To)
|
||||
|
||||
s.mu.Lock()
|
||||
s.messages = append(s.messages, msg)
|
||||
|
||||
// Relay to all other clients (simple broadcast)
|
||||
for client := range s.clients {
|
||||
if client != conn {
|
||||
if err := client.WriteJSON(msg); err != nil {
|
||||
s.t.Logf("Relay error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *mockSignalingServer) GetReceivedMessages() []rtclib.SignalingMessage {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
msgs := make([]rtclib.SignalingMessage, len(s.messages))
|
||||
copy(msgs, s.messages)
|
||||
return msgs
|
||||
}
|
||||
|
||||
func (s *mockSignalingServer) Close() {
|
||||
s.mu.Lock()
|
||||
for conn := range s.clients {
|
||||
conn.Close()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
s.Server.Close()
|
||||
}
|
||||
|
||||
// BenchmarkWebRTCSignaling benchmarks signaling message throughput
// through the mock relay server.
//
// NOTE(review): constructing bare &testing.T{} values so the helpers can log
// is fragile — a zero testing.T is not meant to be used directly; consider
// having the helpers accept testing.TB so *testing.B can be passed — confirm.
func BenchmarkWebRTCSignaling(b *testing.B) {
	server := newMockSignalingServer(&testing.T{})
	defer server.Close()

	client := newSignalingClient(server.URL, func(string, ...any) {})
	handler := &testSignalHandler{t: &testing.T{}}
	err := client.Start(handler)
	if err != nil {
		b.Fatalf("Failed to start client: %v", err)
	}
	defer client.Close()

	time.Sleep(100 * time.Millisecond)

	disco1 := key.NewDisco().Public()
	disco2 := key.NewDisco().Public()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if err := client.Candidate(disco1.String(), disco2.String(), &webrtc.ICECandidateInit{
			Candidate: "test",
		}); err != nil {
			b.Errorf("Send failed: %v", err)
		}
	}
}
|
||||
|
||||
// TestWebRTCPacketFlow tests packet flow simulation: it builds a
// webrtcReadResult by hand and verifies that copyBuf faithfully copies the
// packet bytes into a caller-supplied buffer.
func TestWebRTCPacketFlow(t *testing.T) {
	// Create a mock webrtcReadResult to simulate packet reception
	nodeKey := key.NewNode().Public()
	testPacket := []byte("test wireguard packet")

	result := webrtcReadResult{
		n:   len(testPacket),
		src: nodeKey,
		copyBuf: func(dst []byte) int {
			return copy(dst, testPacket)
		},
	}

	// Verify packet can be copied
	buf := make([]byte, 1024)
	n := result.copyBuf(buf)

	if n != len(testPacket) {
		t.Errorf("copyBuf returned %d bytes, want %d", n, len(testPacket))
	}

	if string(buf[:n]) != string(testPacket) {
		t.Errorf("Packet data mismatch: got %q, want %q", buf[:n], testPacket)
	}

	t.Logf("Packet flow test passed: %d bytes", n)
}
|
||||
|
||||
// TestWebRTCPathSelection tests path selection with WebRTC in the mix: a
// direct UDP address must beat both WebRTC and DERP, and the WebRTC magic
// address must beat DERP, according to betterAddr.
func TestWebRTCPathSelection(t *testing.T) {
	tests := []struct {
		name     string
		paths    []addrQuality
		wantBest string
	}{
		{
			name: "direct_beats_all",
			paths: []addrQuality{
				{epAddr: epAddr{ap: netip.MustParseAddrPort("1.2.3.4:1234")}},
				{epAddr: epAddr{ap: netip.AddrPortFrom(tailcfg.WebRTCMagicIPAddr, 12345)}},
				{epAddr: epAddr{ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, 1)}},
			},
			wantBest: "1.2.3.4:1234",
		},
		{
			name: "webrtc_beats_derp",
			paths: []addrQuality{
				{epAddr: epAddr{ap: netip.AddrPortFrom(tailcfg.WebRTCMagicIPAddr, 12345)}},
				{epAddr: epAddr{ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, 1)}},
			},
			wantBest: "127.3.3.41:12345",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Fold betterAddr over the candidate paths, keeping the winner.
			best := tt.paths[0]
			for _, path := range tt.paths[1:] {
				if betterAddr(path, best) {
					best = path
				}
			}

			if best.ap.String() != tt.wantBest {
				t.Errorf("Best path = %v, want %v", best.ap, tt.wantBest)
			}
		})
	}
}
|
||||
296
wgengine/magicsock/webrtc_signaling.go
Normal file
296
wgengine/magicsock/webrtc_signaling.go
Normal file
@ -0,0 +1,296 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"tailscale.com/rtclib"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
// Ensure signalingClient implements rtclib.Signaller interface.
|
||||
var _ rtclib.Signaller = (*signalingClient)(nil)
|
||||
|
||||
// signalingClient manages WebSocket connection to signaling server.
// It reconnects with exponential backoff and relays inbound messages to the
// registered rtclib.SignalHandler.
type signalingClient struct {
	url    string
	logf   logger.Logf
	conn   *websocket.Conn // guarded by connMu
	connMu sync.Mutex

	// Message handler
	handler rtclib.SignalHandler

	// Control channels
	ctx       context.Context
	ctxCancel context.CancelFunc
	sendCh    chan *rtclib.SignalingMessage // outbound queue drained by writeLoop
	closedCh  chan struct{}                 // closed when runLoop exits

	// Reconnection state
	reconnectDelay time.Duration // current backoff; doubled per failure up to maxDelay
	maxDelay       time.Duration
}
|
||||
|
||||
// newSignalingClient creates a new signaling client.
|
||||
func newSignalingClient(url string, logf logger.Logf) *signalingClient {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &signalingClient{
|
||||
url: url,
|
||||
logf: logf,
|
||||
ctx: ctx,
|
||||
ctxCancel: cancel,
|
||||
sendCh: make(chan *rtclib.SignalingMessage, 16),
|
||||
closedCh: make(chan struct{}),
|
||||
reconnectDelay: time.Second,
|
||||
maxDelay: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the signaling client's connection and message loops.
|
||||
func (sc *signalingClient) Start(handler rtclib.SignalHandler) error {
|
||||
sc.handler = handler
|
||||
if err := sc.connect(); err != nil {
|
||||
return fmt.Errorf("initial signaling connection failed: %w", err)
|
||||
}
|
||||
go sc.runLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
// connect establishes WebSocket connection to signaling server.
|
||||
func (sc *signalingClient) connect() error {
|
||||
sc.connMu.Lock()
|
||||
defer sc.connMu.Unlock()
|
||||
|
||||
if sc.conn != nil {
|
||||
return nil // already connected
|
||||
}
|
||||
|
||||
conn, _, err := websocket.Dial(sc.ctx, sc.url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("websocket dial failed: %w", err)
|
||||
}
|
||||
|
||||
sc.conn = conn
|
||||
sc.reconnectDelay = time.Second // reset backoff on successful connection
|
||||
sc.logf("signaling: connected to %s", sc.url)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the signaling client. It cancels the context, closes the
// active connection to unblock in-flight reads/writes, then waits (up to 2s)
// for runLoop to exit before clearing the connection. Always returns nil.
func (sc *signalingClient) Close() error {
	// Cancel context to signal all goroutines to stop
	sc.ctxCancel()

	// Close the connection to unblock any read/write operations
	sc.connMu.Lock()
	if sc.conn != nil {
		sc.conn.Close(websocket.StatusNormalClosure, "")
	}
	sc.connMu.Unlock()

	// Wait for runLoop to finish with timeout
	select {
	case <-sc.closedCh:
	case <-time.After(2 * time.Second):
		sc.logf("signaling: close timed out, forcing shutdown")
	}

	sc.connMu.Lock()
	defer sc.connMu.Unlock()
	sc.conn = nil
	return nil
}
|
||||
|
||||
// send queues a message to be sent to the signaling server.
// It never blocks: if the outbound queue is full, the message is dropped
// with an error rather than stalling the caller (the default case fires).
func (sc *signalingClient) send(msg *rtclib.SignalingMessage) error {
	select {
	case sc.sendCh <- msg:
		return nil
	case <-sc.ctx.Done():
		return sc.ctx.Err()
	default:
		return errors.New("signaling send queue full")
	}
}
|
||||
|
||||
// runLoop manages connection lifecycle and message routing.
// It (re)connects with exponential backoff, spawns a read and a write loop
// per connection, and tears the connection down when either loop reports an
// error. errCh is buffered to 2 so both loops can report without blocking.
//
// NOTE(review): after disconnect() a surviving loop from the previous
// iteration only exits once its next conn operation fails; verify both old
// goroutines terminate before new ones are spawned.
func (sc *signalingClient) runLoop() {
	defer close(sc.closedCh)

	for {
		select {
		case <-sc.ctx.Done():
			return
		default:
		}

		// Ensure we're connected
		if err := sc.ensureConnected(); err != nil {
			sc.logf("signaling: connection failed, retrying in %v: %v", sc.reconnectDelay, err)
			select {
			case <-time.After(sc.reconnectDelay):
				sc.reconnectDelay = min(sc.reconnectDelay*2, sc.maxDelay)
				continue
			case <-sc.ctx.Done():
				return
			}
		}

		// Run read/write loops
		errCh := make(chan error, 2)
		go sc.readLoop(errCh)
		go sc.writeLoop(errCh)

		// Wait for error or context cancellation
		select {
		case err := <-errCh:
			sc.logf("signaling: connection error: %v", err)
			sc.disconnect()
		case <-sc.ctx.Done():
			sc.disconnect()
			return
		}
	}
}
|
||||
|
||||
// ensureConnected ensures connection is established.
|
||||
func (sc *signalingClient) ensureConnected() error {
|
||||
sc.connMu.Lock()
|
||||
connected := sc.conn != nil
|
||||
sc.connMu.Unlock()
|
||||
|
||||
if connected {
|
||||
return nil
|
||||
}
|
||||
|
||||
return sc.connect()
|
||||
}
|
||||
|
||||
// disconnect closes the current connection.
|
||||
func (sc *signalingClient) disconnect() {
|
||||
sc.connMu.Lock()
|
||||
defer sc.connMu.Unlock()
|
||||
|
||||
if sc.conn != nil {
|
||||
sc.conn.Close(websocket.StatusNormalClosure, "")
|
||||
sc.conn = nil
|
||||
sc.logf("signaling: disconnected")
|
||||
}
|
||||
}
|
||||
|
||||
// readLoop reads messages from WebSocket and dispatches them to the
// registered handler by message type. It exits (reporting on errCh) when the
// connection is gone or a read fails.
func (sc *signalingClient) readLoop(errCh chan<- error) {
	for {
		// Re-fetch the connection each iteration under the lock.
		sc.connMu.Lock()
		conn := sc.conn
		sc.connMu.Unlock()

		if conn == nil {
			errCh <- errors.New("no connection")
			return
		}

		var msg rtclib.SignalingMessage
		if err := wsjson.Read(sc.ctx, conn, &msg); err != nil {
			errCh <- fmt.Errorf("read failed: %w", err)
			return
		}

		if sc.handler != nil {
			switch msg.Type {
			case rtclib.MessageTypeOffer:
				sc.handler.HandleOffer(msg.From, msg.To, msg.Offer)
			case rtclib.MessageTypeAnswer:
				sc.handler.HandleAnswer(msg.From, msg.To, msg.Answer)
			case rtclib.MessageTypeCandidate:
				sc.handler.HandleCandidate(msg.From, msg.To, msg.Candidate)
			default:
				sc.logf("signaling: unknown message type: %s", msg.Type)
			}
		}
	}
}
|
||||
|
||||
// writeLoop writes messages to WebSocket. It drains the outbound queue,
// pings every 30s to keep the connection alive, and exits (reporting on
// errCh) on any failure or when the context is canceled.
func (sc *signalingClient) writeLoop(errCh chan<- error) {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case msg := <-sc.sendCh:
			sc.connMu.Lock()
			conn := sc.conn
			sc.connMu.Unlock()

			if conn == nil {
				errCh <- errors.New("no connection")
				return
			}

			if err := wsjson.Write(sc.ctx, conn, msg); err != nil {
				errCh <- fmt.Errorf("write failed: %w", err)
				return
			}

		case <-ticker.C:
			// Send ping to keep connection alive
			sc.connMu.Lock()
			conn := sc.conn
			sc.connMu.Unlock()

			if conn == nil {
				errCh <- errors.New("no connection")
				return
			}

			if err := conn.Ping(sc.ctx); err != nil {
				errCh <- fmt.Errorf("ping failed: %w", err)
				return
			}

		case <-sc.ctx.Done():
			return
		}
	}
}
|
||||
|
||||
// Offer sends an SDP offer to a peer.
|
||||
func (sc *signalingClient) Offer(from, to string, offer *webrtc.SessionDescription) error {
|
||||
return sc.send(&rtclib.SignalingMessage{
|
||||
Type: rtclib.MessageTypeOffer,
|
||||
From: from,
|
||||
To: to,
|
||||
Offer: offer,
|
||||
})
|
||||
}
|
||||
|
||||
// Answer sends an SDP answer to a peer.
|
||||
func (sc *signalingClient) Answer(from, to string, answer *webrtc.SessionDescription) error {
|
||||
return sc.send(&rtclib.SignalingMessage{
|
||||
Type: rtclib.MessageTypeAnswer,
|
||||
From: from,
|
||||
To: to,
|
||||
Answer: answer,
|
||||
})
|
||||
}
|
||||
|
||||
// Candidate sends an ICE candidate to a peer.
|
||||
func (sc *signalingClient) Candidate(from, to string, candidate *webrtc.ICECandidateInit) error {
|
||||
return sc.send(&rtclib.SignalingMessage{
|
||||
Type: rtclib.MessageTypeCandidate,
|
||||
From: from,
|
||||
To: to,
|
||||
Candidate: candidate,
|
||||
})
|
||||
}
|
||||
101
wgengine/magicsock/webrtc_signaling_disco.go
Normal file
101
wgengine/magicsock/webrtc_signaling_disco.go
Normal file
@ -0,0 +1,101 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/pion/webrtc/v4"
|
||||
"tailscale.com/disco"
|
||||
"tailscale.com/rtclib"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// discoSignaller implements rtclib.Signaller by routing WebRTC signaling
// messages through the existing Tailscale disco/DERP infrastructure. This
// eliminates the need for an external signaling server: SDP offers/answers and
// ICE candidates travel as encrypted disco messages relayed over DERP, using
// the same authenticated peer-to-peer path that Tailscale already maintains.
type discoSignaller struct {
	conn    *Conn                // owning magicsock Conn; used to locate peers and send disco messages
	handler rtclib.SignalHandler // set by Start; receives inbound signaling callbacks
}

// Ensure discoSignaller implements rtclib.Signaller.
var _ rtclib.Signaller = (*discoSignaller)(nil)
|
||||
|
||||
// Start implements rtclib.Signaller. It records the handler that will receive
// inbound signaling callbacks.
//
// NOTE(review): handler is assigned without synchronization — this assumes
// Start is called before any disco WebRTC messages can arrive; confirm with
// the Conn setup path.
func (ds *discoSignaller) Start(handler rtclib.SignalHandler) error {
	ds.handler = handler
	return nil
}
|
||||
|
||||
// Close implements rtclib.Signaller. Nothing to tear down — the disco/DERP
// path is managed by the surrounding Conn, so this is always a no-op success.
func (ds *discoSignaller) Close() error { return nil }
|
||||
|
||||
// Offer implements rtclib.Signaller.
|
||||
func (ds *discoSignaller) Offer(from, to string, offer *webrtc.SessionDescription) error {
|
||||
payload, err := json.Marshal(offer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("webrtc disco signaller: marshal offer: %w", err)
|
||||
}
|
||||
return ds.send(to, &disco.WebRTCOffer{Payload: payload})
|
||||
}
|
||||
|
||||
// Answer implements rtclib.Signaller.
|
||||
func (ds *discoSignaller) Answer(from, to string, answer *webrtc.SessionDescription) error {
|
||||
payload, err := json.Marshal(answer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("webrtc disco signaller: marshal answer: %w", err)
|
||||
}
|
||||
return ds.send(to, &disco.WebRTCAnswer{Payload: payload})
|
||||
}
|
||||
|
||||
// Candidate implements rtclib.Signaller.
|
||||
func (ds *discoSignaller) Candidate(from, to string, candidate *webrtc.ICECandidateInit) error {
|
||||
payload, err := json.Marshal(candidate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("webrtc disco signaller: marshal candidate: %w", err)
|
||||
}
|
||||
return ds.send(to, &disco.WebRTCICECandidate{Payload: payload})
|
||||
}
|
||||
|
||||
// send routes a disco WebRTC message to the peer identified by toDisco (a hex
// disco public key string), via that peer's home DERP region.
//
// It returns an error if toDisco does not parse as a disco public key, if no
// known endpoint uses that disco key, or if the matching endpoint has no
// valid DERP home address to relay through.
func (ds *discoSignaller) send(toDisco string, m disco.Message) error {
	var toKey key.DiscoPublic
	if err := toKey.UnmarshalText([]byte(toDisco)); err != nil {
		return fmt.Errorf("webrtc disco signaller: parse disco key %q: %w", toDisco, err)
	}

	// Find the endpoint and its home DERP address under the Conn lock.
	ds.conn.mu.Lock()
	var (
		derpAddr netip.AddrPort
		nodeKey  key.NodePublic
		found    bool
	)
	ds.conn.peerMap.forEachEndpointWithDiscoKey(toKey, func(ep *endpoint) bool {
		// ep.derpAddr is read under ep.mu; ep.publicKey is read outside it
		// (presumably immutable after endpoint creation — TODO confirm).
		ep.mu.Lock()
		derpAddr = ep.derpAddr
		ep.mu.Unlock()
		nodeKey = ep.publicKey
		found = true
		return false // stop after first match
	})
	ds.conn.mu.Unlock()

	if !found {
		return fmt.Errorf("webrtc disco signaller: no endpoint for disco key %v", toKey.ShortString())
	}
	if !derpAddr.IsValid() {
		return fmt.Errorf("webrtc disco signaller: no DERP address for peer %v", toKey.ShortString())
	}

	// Send only after Conn.mu is released, so the lock is not held across
	// network I/O.
	_, err := ds.conn.sendDiscoMessage(epAddr{ap: derpAddr}, nodeKey, toKey, m, discoLog)
	return err
}
|
||||
323
wgengine/magicsock/webrtc_test.go
Normal file
323
wgengine/magicsock/webrtc_test.go
Normal file
@ -0,0 +1,323 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package magicsock
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"tailscale.com/rtclib"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// TestSignalingMessageEncoding tests JSON encoding/decoding of signaling messages
|
||||
func TestSignalingMessageEncoding(t *testing.T) {
|
||||
disco1 := key.NewDisco()
|
||||
disco2 := key.NewDisco()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
msg rtclib.SignalingMessage
|
||||
}{
|
||||
{
|
||||
name: "offer",
|
||||
msg: rtclib.SignalingMessage{
|
||||
Type: rtclib.MessageTypeOffer,
|
||||
From: disco1.Public().String(),
|
||||
To: disco2.Public().String(),
|
||||
Offer: &webrtc.SessionDescription{Type: webrtc.SDPTypeOffer, SDP: "v=0..."},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "answer",
|
||||
msg: rtclib.SignalingMessage{
|
||||
Type: rtclib.MessageTypeAnswer,
|
||||
From: disco2.Public().String(),
|
||||
To: disco1.Public().String(),
|
||||
Answer: &webrtc.SessionDescription{Type: webrtc.SDPTypeAnswer, SDP: "v=0..."},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "candidate",
|
||||
msg: rtclib.SignalingMessage{
|
||||
Type: rtclib.MessageTypeCandidate,
|
||||
From: disco1.Public().String(),
|
||||
To: disco2.Public().String(),
|
||||
Candidate: &webrtc.ICECandidateInit{Candidate: "..."},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Encode
|
||||
data, err := json.Marshal(tt.msg)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
// Decode
|
||||
var decoded rtclib.SignalingMessage
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify
|
||||
if decoded.Type != tt.msg.Type {
|
||||
t.Errorf("Type mismatch: got %v, want %v", decoded.Type, tt.msg.Type)
|
||||
}
|
||||
if decoded.From != tt.msg.From {
|
||||
t.Errorf("From mismatch: got %v, want %v", decoded.From, tt.msg.From)
|
||||
}
|
||||
if decoded.To != tt.msg.To {
|
||||
t.Errorf("To mismatch: got %v, want %v", decoded.To, tt.msg.To)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockSignalHandler is a test implementation of rtclib.SignalHandler.
// It counts how many times each callback fires; mu guards all three counters.
type mockSignalHandler struct {
	offerCount     int
	answerCount    int
	candidateCount int
	mu             sync.Mutex
}

// HandleOffer implements rtclib.SignalHandler by counting received offers.
func (m *mockSignalHandler) HandleOffer(from, to string, offer *webrtc.SessionDescription) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.offerCount++
}

// HandleAnswer implements rtclib.SignalHandler by counting received answers.
func (m *mockSignalHandler) HandleAnswer(from, to string, answer *webrtc.SessionDescription) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.answerCount++
}

// HandleCandidate implements rtclib.SignalHandler by counting received ICE
// candidates.
func (m *mockSignalHandler) HandleCandidate(from, to string, candidate *webrtc.ICECandidateInit) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.candidateCount++
}
|
||||
|
||||
// TestSignalingClientReconnect tests reconnection with backoff
|
||||
func TestSignalingClientReconnect(t *testing.T) {
|
||||
var connectCount int
|
||||
var mu sync.Mutex
|
||||
|
||||
// Mock WebSocket server that closes connections after accepting them
|
||||
upgrader := websocket.Upgrader{}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
connectCount++
|
||||
mu.Unlock()
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Immediately close to force reconnection
|
||||
conn.Close()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Convert http:// to ws://
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
|
||||
client := newSignalingClient(wsURL, t.Logf)
|
||||
handler := &mockSignalHandler{}
|
||||
err := client.Start(handler)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start client: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Wait for a few reconnection attempts
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
mu.Lock()
|
||||
count := connectCount
|
||||
mu.Unlock()
|
||||
|
||||
// Should have attempted to connect multiple times
|
||||
if count < 2 {
|
||||
t.Errorf("Expected multiple reconnection attempts, got %d", count)
|
||||
}
|
||||
|
||||
t.Logf("Reconnected %d times", count)
|
||||
}
|
||||
|
||||
// TestWebRTCMagicIP tests the WebRTC magic IP constant
|
||||
func TestWebRTCMagicIP(t *testing.T) {
|
||||
if tailcfg.WebRTCMagicIPAddr.String() != "127.3.3.41" {
|
||||
t.Errorf("WebRTC magic IP = %v, want 127.3.3.41", tailcfg.WebRTCMagicIPAddr)
|
||||
}
|
||||
|
||||
// Verify it's different from DERP magic IP
|
||||
if tailcfg.WebRTCMagicIPAddr == tailcfg.DerpMagicIPAddr {
|
||||
t.Error("WebRTC magic IP should be different from DERP magic IP")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebRTCPathPriority tests path preference logic
|
||||
func TestWebRTCPathPriority(t *testing.T) {
|
||||
directV4 := addrQuality{
|
||||
epAddr: epAddr{
|
||||
ap: netip.MustParseAddrPort("192.168.1.100:41641"),
|
||||
},
|
||||
latency: 10 * time.Millisecond,
|
||||
}
|
||||
|
||||
webrtc := addrQuality{
|
||||
epAddr: epAddr{
|
||||
ap: netip.AddrPortFrom(tailcfg.WebRTCMagicIPAddr, 12345),
|
||||
},
|
||||
latency: 50 * time.Millisecond,
|
||||
}
|
||||
|
||||
derp := addrQuality{
|
||||
epAddr: epAddr{
|
||||
ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, 1),
|
||||
},
|
||||
latency: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
a, b addrQuality
|
||||
want bool // true if a is better than b
|
||||
}{
|
||||
{
|
||||
name: "direct beats WebRTC",
|
||||
a: directV4,
|
||||
b: webrtc,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "WebRTC beats DERP",
|
||||
a: webrtc,
|
||||
b: derp,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "direct beats DERP",
|
||||
a: directV4,
|
||||
b: derp,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "DERP loses to WebRTC",
|
||||
a: derp,
|
||||
b: webrtc,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "WebRTC loses to direct",
|
||||
a: webrtc,
|
||||
b: directV4,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := betterAddr(tt.a, tt.b)
|
||||
if got != tt.want {
|
||||
t.Errorf("betterAddr(%v, %v) = %v, want %v", tt.a.ap, tt.b.ap, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebRTCReadResult tests webrtcReadResult structure
|
||||
func TestWebRTCReadResult(t *testing.T) {
|
||||
nodeKey := key.NewNode()
|
||||
testData := []byte("test packet data")
|
||||
|
||||
result := webrtcReadResult{
|
||||
n: len(testData),
|
||||
src: nodeKey.Public(),
|
||||
copyBuf: func(dst []byte) int {
|
||||
return copy(dst, testData)
|
||||
},
|
||||
}
|
||||
|
||||
// Test copyBuf
|
||||
buf := make([]byte, 100)
|
||||
n := result.copyBuf(buf)
|
||||
if n != len(testData) {
|
||||
t.Errorf("copyBuf returned %d, want %d", n, len(testData))
|
||||
}
|
||||
if string(buf[:n]) != string(testData) {
|
||||
t.Errorf("copyBuf data = %q, want %q", buf[:n], testData)
|
||||
}
|
||||
|
||||
// Test fields
|
||||
if result.n != len(testData) {
|
||||
t.Errorf("result.n = %d, want %d", result.n, len(testData))
|
||||
}
|
||||
if result.src != nodeKey.Public() {
|
||||
t.Errorf("result.src mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDiscoRXPathWebRTC tests the WebRTC disco path constant
|
||||
func TestDiscoRXPathWebRTC(t *testing.T) {
|
||||
if discoRXPathWebRTC != "WebRTC" {
|
||||
t.Errorf("discoRXPathWebRTC = %q, want %q", discoRXPathWebRTC, "WebRTC")
|
||||
}
|
||||
|
||||
// Verify it's different from other paths
|
||||
if discoRXPathWebRTC == discoRXPathDERP {
|
||||
t.Error("WebRTC path should be different from DERP path")
|
||||
}
|
||||
if discoRXPathWebRTC == discoRXPathUDP {
|
||||
t.Error("WebRTC path should be different from UDP path")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebRTCMetrics tests that WebRTC metrics are properly defined.
// Each check is a direct nil comparison against the package-level variable
// (no table/interface indirection), which avoids the typed-nil-in-interface
// boxing trap.
func TestWebRTCMetrics(t *testing.T) {
	// Test that metric variables exist (they're package-level variables)
	if metricRecvDataPacketsWebRTC == nil {
		t.Error("metricRecvDataPacketsWebRTC should be initialized")
	}
	if metricRecvDataBytesWebRTC == nil {
		t.Error("metricRecvDataBytesWebRTC should be initialized")
	}
	if metricSendDataPacketsWebRTC == nil {
		t.Error("metricSendDataPacketsWebRTC should be initialized")
	}
	if metricSendDataBytesWebRTC == nil {
		t.Error("metricSendDataBytesWebRTC should be initialized")
	}

	t.Log("WebRTC metrics are properly defined")
}
|
||||
|
||||
// TestPathWebRTCConstant tests the PathWebRTC constant
|
||||
func TestPathWebRTCConstant(t *testing.T) {
|
||||
if PathWebRTC != "webrtc" {
|
||||
t.Errorf("PathWebRTC = %q, want %q", PathWebRTC, "webrtc")
|
||||
}
|
||||
|
||||
// Verify it's different from other paths
|
||||
if PathWebRTC == PathDERP {
|
||||
t.Error("PathWebRTC should be different from PathDERP")
|
||||
}
|
||||
if PathWebRTC == PathDirectIPv4 {
|
||||
t.Error("PathWebRTC should be different from PathDirectIPv4")
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user