mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-11-04 10:11:18 +01:00 
			
		
		
		
	If you had HTTPS_PROXY=https://some-valid-cert.example.com running a CONNECT proxy, we should have been able to make a TLS CONNECT request to, e.g., controlplane.tailscale.com:443 through it. I'm fairly sure this used to work, but refactorings and a lack of integration tests let it regress. It probably regressed when we added the baked-in Let's Encrypt root certificate validation fallback code, which was validating against the wrong hostname (the ultimate destination, not the one we were being asked to validate). Fixes #16222 Change-Id: If014e395f830e2f87f056f588edacad5c15e91bc Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
		
			
				
	
	
		
			94 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			94 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright (c) Tailscale Inc & AUTHORS
 | 
						|
// SPDX-License-Identifier: BSD-3-Clause
 | 
						|
 | 
						|
// Package connectproxy contains some CONNECT proxy code.
 | 
						|
package connectproxy
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"io"
 | 
						|
	"log"
 | 
						|
	"net"
 | 
						|
	"net/http"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"tailscale.com/net/netx"
 | 
						|
	"tailscale.com/types/logger"
 | 
						|
)
 | 
						|
 | 
						|
// Handler is an HTTP CONNECT proxy handler.
 | 
						|
type Handler struct {
 | 
						|
	// Dial, if non-nil, is an alternate dialer to use
 | 
						|
	// instead of the default dialer.
 | 
						|
	Dial netx.DialFunc
 | 
						|
 | 
						|
	// Logf, if non-nil, is an alterate logger to
 | 
						|
	// use instead of log.Printf.
 | 
						|
	Logf logger.Logf
 | 
						|
 | 
						|
	// Check, if non-nil, validates the CONNECT target.
 | 
						|
	Check func(hostPort string) error
 | 
						|
}
 | 
						|
 | 
						|
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 | 
						|
	ctx := r.Context()
 | 
						|
	if r.Method != "CONNECT" {
 | 
						|
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	dial := h.Dial
 | 
						|
	if dial == nil {
 | 
						|
		var d net.Dialer
 | 
						|
		dial = d.DialContext
 | 
						|
	}
 | 
						|
	logf := h.Logf
 | 
						|
	if logf == nil {
 | 
						|
		logf = log.Printf
 | 
						|
	}
 | 
						|
 | 
						|
	hostPort := r.RequestURI
 | 
						|
	if h.Check != nil {
 | 
						|
		if err := h.Check(hostPort); err != nil {
 | 
						|
			logf("CONNECT target %q not allowed: %v", hostPort, err)
 | 
						|
			http.Error(w, "Invalid CONNECT target", http.StatusForbidden)
 | 
						|
			return
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
 | 
						|
	defer cancel()
 | 
						|
	back, err := dial(ctx, "tcp", hostPort)
 | 
						|
	if err != nil {
 | 
						|
		logf("error CONNECT dialing %v: %v", hostPort, err)
 | 
						|
		http.Error(w, "Connect failure", http.StatusBadGateway)
 | 
						|
		return
 | 
						|
	}
 | 
						|
	defer back.Close()
 | 
						|
 | 
						|
	hj, ok := w.(http.Hijacker)
 | 
						|
	if !ok {
 | 
						|
		http.Error(w, "CONNECT hijack unavailable", http.StatusInternalServerError)
 | 
						|
		return
 | 
						|
	}
 | 
						|
	c, br, err := hj.Hijack()
 | 
						|
	if err != nil {
 | 
						|
		logf("CONNECT hijack: %v", err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
	defer c.Close()
 | 
						|
 | 
						|
	io.WriteString(c, "HTTP/1.1 200 OK\r\n\r\n")
 | 
						|
 | 
						|
	errc := make(chan error, 2)
 | 
						|
	go func() {
 | 
						|
		_, err := io.Copy(c, back)
 | 
						|
		errc <- err
 | 
						|
	}()
 | 
						|
	go func() {
 | 
						|
		_, err := io.Copy(back, br)
 | 
						|
		errc <- err
 | 
						|
	}()
 | 
						|
	<-errc
 | 
						|
}
 |