mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-26 05:41:04 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			1000 lines
		
	
	
		
			31 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			1000 lines
		
	
	
		
			31 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) Tailscale Inc & AUTHORS
 | |
| // SPDX-License-Identifier: BSD-3-Clause
 | |
| 
 | |
| // Package tsweb contains code used in various Tailscale webservers.
 | |
| package tsweb
 | |
| 
 | |
| import (
 | |
| 	"bufio"
 | |
| 	"bytes"
 | |
| 	"cmp"
 | |
| 	"context"
 | |
| 	"errors"
 | |
| 	"expvar"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"net/netip"
 | |
| 	"net/url"
 | |
| 	"os"
 | |
| 	"path/filepath"
 | |
| 	"regexp"
 | |
| 	"runtime"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| 
 | |
| 	"go4.org/mem"
 | |
| 	"tailscale.com/envknob"
 | |
| 	"tailscale.com/metrics"
 | |
| 	"tailscale.com/net/tsaddr"
 | |
| 	"tailscale.com/tsweb/varz"
 | |
| 	"tailscale.com/types/logger"
 | |
| 	"tailscale.com/util/ctxkey"
 | |
| 	"tailscale.com/util/vizerror"
 | |
| )
 | |
| 
 | |
| // DevMode controls whether extra output is shown, for when the binary is being run in dev mode.
 | |
| var DevMode bool
 | |
| 
 | |
// DefaultCertDir returns the directory used by default to store certificates:
// the "tailscale" subdirectory of the user's cache directory (as reported by
// os.UserCacheDir), joined with leafDir.
//
// It returns the empty string if the user cache directory cannot be
// determined.
func DefaultCertDir(leafDir string) string {
	base, err := os.UserCacheDir()
	if err != nil {
		return ""
	}
	return filepath.Join(base, "tailscale", leafDir)
}
 | |
| 
 | |
// IsProd443 reports whether addr is a Go listen address for port 443,
// given either numerically ("443") or by service name ("https").
//
// An addr that cannot be split into host and port is reported as false.
func IsProd443(addr string) bool {
	// On a split error the port is empty and matches neither case below.
	_, port, _ := net.SplitHostPort(addr)
	switch port {
	case "443", "https":
		return true
	}
	return false
}
 | |
| 
 | |
| // AllowDebugAccess reports whether r should be permitted to access
 | |
| // various debug endpoints.
 | |
| func AllowDebugAccess(r *http.Request) bool {
 | |
| 	if allowDebugAccessWithKey(r) {
 | |
| 		return true
 | |
| 	}
 | |
| 	if r.Header.Get("X-Forwarded-For") != "" {
 | |
| 		// TODO if/when needed. For now, conservative:
 | |
| 		return false
 | |
| 	}
 | |
| 	ipStr, _, err := net.SplitHostPort(r.RemoteAddr)
 | |
| 	if err != nil {
 | |
| 		return false
 | |
| 	}
 | |
| 	ip, err := netip.ParseAddr(ipStr)
 | |
| 	if err != nil {
 | |
| 		return false
 | |
| 	}
 | |
| 	if tsaddr.IsTailscaleIP(ip) || ip.IsLoopback() || ipStr == envknob.String("TS_ALLOW_DEBUG_IP") {
 | |
| 		return true
 | |
| 	}
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| func allowDebugAccessWithKey(r *http.Request) bool {
 | |
| 	if r.Method != "GET" {
 | |
| 		return false
 | |
| 	}
 | |
| 	urlKey := r.FormValue("debugkey")
 | |
| 	keyPath := envknob.String("TS_DEBUG_KEY_PATH")
 | |
| 	if urlKey != "" && keyPath != "" {
 | |
| 		slurp, err := os.ReadFile(keyPath)
 | |
| 		if err == nil && string(bytes.TrimSpace(slurp)) == urlKey {
 | |
| 			return true
 | |
| 		}
 | |
| 	}
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| // AcceptsEncoding reports whether r accepts the named encoding
 | |
| // ("gzip", "br", etc).
 | |
| func AcceptsEncoding(r *http.Request, enc string) bool {
 | |
| 	h := r.Header.Get("Accept-Encoding")
 | |
| 	if h == "" {
 | |
| 		return false
 | |
| 	}
 | |
| 	if !strings.Contains(h, enc) && !mem.ContainsFold(mem.S(h), mem.S(enc)) {
 | |
| 		return false
 | |
| 	}
 | |
| 	remain := h
 | |
| 	for len(remain) > 0 {
 | |
| 		var part string
 | |
| 		part, remain, _ = strings.Cut(remain, ",")
 | |
| 		part = strings.TrimSpace(part)
 | |
| 		part, _, _ = strings.Cut(part, ";")
 | |
| 		if part == enc {
 | |
| 			return true
 | |
| 		}
 | |
| 	}
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| // Protected wraps a provided debug handler, h, returning a Handler
 | |
| // that enforces AllowDebugAccess and returns forbidden replies for
 | |
| // unauthorized requests.
 | |
| func Protected(h http.Handler) http.Handler {
 | |
| 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		if !AllowDebugAccess(r) {
 | |
| 			msg := "debug access denied"
 | |
| 			if DevMode {
 | |
| 				ipStr, _, _ := net.SplitHostPort(r.RemoteAddr)
 | |
| 				msg += fmt.Sprintf("; to permit access, set TS_ALLOW_DEBUG_IP=%v", ipStr)
 | |
| 			}
 | |
| 			http.Error(w, msg, http.StatusForbidden)
 | |
| 			return
 | |
| 		}
 | |
| 		h.ServeHTTP(w, r)
 | |
| 	})
 | |
| }
 | |
| 
 | |
| // Port80Handler is the handler to be given to
 | |
| // autocert.Manager.HTTPHandler.  The inner handler is the mux
 | |
| // returned by NewMux containing registered /debug handlers.
 | |
| type Port80Handler struct {
 | |
| 	Main http.Handler
 | |
| 	// FQDN is used to redirect incoming requests to https://<FQDN>.
 | |
| 	// If it is not set, the hostname is calculated from the incoming
 | |
| 	// request.
 | |
| 	FQDN string
 | |
| }
 | |
| 
 | |
| func (h Port80Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 | |
| 	path := r.RequestURI
 | |
| 	if path == "/debug" || strings.HasPrefix(path, "/debug") {
 | |
| 		h.Main.ServeHTTP(w, r)
 | |
| 		return
 | |
| 	}
 | |
| 	if r.Method != "GET" && r.Method != "HEAD" {
 | |
| 		http.Error(w, "Use HTTPS", http.StatusBadRequest)
 | |
| 		return
 | |
| 	}
 | |
| 	if path == "/" && AllowDebugAccess(r) {
 | |
| 		// Redirect authorized user to the debug handler.
 | |
| 		path = "/debug/"
 | |
| 	}
 | |
| 	host := cmp.Or(h.FQDN, r.Host)
 | |
| 	target := "https://" + host + path
 | |
| 	http.Redirect(w, r, target, http.StatusFound)
 | |
| }
 | |
| 
 | |
| // ReturnHandler is like net/http.Handler, but the handler can return an
 | |
| // error instead of writing to its ResponseWriter.
 | |
| type ReturnHandler interface {
 | |
| 	// ServeHTTPReturn is like http.Handler.ServeHTTP, except that
 | |
| 	// it can choose to return an error instead of writing to its
 | |
| 	// http.ResponseWriter.
 | |
| 	//
 | |
| 	// If ServeHTTPReturn returns an error, its caller should handle
 | |
| 	// the error by serving an HTTP 500 response to the user. The
 | |
| 	// error details should not be sent to the client, as they may
 | |
| 	// contain sensitive information. If the error is an
 | |
| 	// HTTPError, though, callers should use the HTTP response
 | |
| 	// code and message as the response to the client.
 | |
| 	ServeHTTPReturn(http.ResponseWriter, *http.Request) error
 | |
| }
 | |
| 
 | |
| // BucketedStatsOptions describes tsweb handler options surrounding
 | |
| // the generation of metrics, grouped into buckets.
 | |
| type BucketedStatsOptions struct {
 | |
| 	// Bucket returns which bucket the given request is in.
 | |
| 	// If nil, [NormalizedPath] is used to compute the bucket.
 | |
| 	Bucket func(req *http.Request) string
 | |
| 
 | |
| 	// If non-nil, Started maintains a counter of all requests which
 | |
| 	// have begun processing.
 | |
| 	Started *metrics.LabelMap
 | |
| 
 | |
| 	// If non-nil, Finished maintains a counter of all requests which
 | |
| 	// have finished processing with success (that is, the HTTP handler has
 | |
| 	// returned).
 | |
| 	Finished *metrics.LabelMap
 | |
| }
 | |
| 
 | |
// normalizePathRegex matches components in a HTTP request path
// that should be replaced.
//
// See: https://regex101.com/r/WIfpaR/3 for the explainer and test cases.
var normalizePathRegex = regexp.MustCompile("([a-fA-F0-9]{9,}|([^\\/])+\\.([^\\/]){2,}|((n|k|u|L|t|S)[a-zA-Z0-9]{5,}(CNTRL|Djz1H|LV5CY|mxgaY|jNy1b))|(([^\\/])+\\@passkey))")

// NormalizedPath returns the given path with the following modifications:
//
//   - any query parameters are removed
//   - any hex string of 9 or more characters is replaced by an ellipsis
//   - any path component containing a period with at least two characters
//     after the period (i.e. an email or domain) is replaced by an ellipsis
//   - any common Tailscale Stable ID is replaced by an ellipsis
//   - any path segment *@passkey is replaced by an ellipsis
func NormalizedPath(p string) string {
	// Strip the query string before matching. The "email or domain"
	// alternative accepts any non-slash characters, including '?', so a
	// period inside the query would otherwise make the match swallow both
	// the '?' and the path component in front of it.
	p, _, _ = strings.Cut(p, "?")

	// Fast path: nothing to replace, so return the input slice as is.
	if !normalizePathRegex.MatchString(p) {
		return p
	}
	return normalizePathRegex.ReplaceAllString(p, "…")
}
 | |
| 
 | |
| func (o *BucketedStatsOptions) bucketForRequest(r *http.Request) string {
 | |
| 	if o.Bucket != nil {
 | |
| 		return o.Bucket(r)
 | |
| 	}
 | |
| 
 | |
| 	return NormalizedPath(r.URL.Path)
 | |
| }
 | |
| 
 | |
| // HandlerOptions are options used by [StdHandler], containing both [LogOptions]
 | |
| // used by [LogHandler] and [ErrorOptions] used by [ErrorHandler].
 | |
| type HandlerOptions struct {
 | |
| 	QuietLoggingIfSuccessful bool // if set, do not log successfully handled HTTP requests (200 and 304 status codes)
 | |
| 	Logf                     logger.Logf
 | |
| 	Now                      func() time.Time // if nil, defaults to time.Now
 | |
| 
 | |
| 	// If non-nil, StatusCodeCounters maintains counters
 | |
| 	// of status codes for handled responses.
 | |
| 	// The keys are "1xx", "2xx", "3xx", "4xx", and "5xx".
 | |
| 	StatusCodeCounters *expvar.Map
 | |
| 	// If non-nil, StatusCodeCountersFull maintains counters of status
 | |
| 	// codes for handled responses.
 | |
| 	// The keys are HTTP numeric response codes e.g. 200, 404, ...
 | |
| 	StatusCodeCountersFull *expvar.Map
 | |
| 
 | |
| 	// If non-nil, BucketedStats computes and exposes statistics
 | |
| 	// for each bucket based on the contained parameters.
 | |
| 	BucketedStats *BucketedStatsOptions
 | |
| 
 | |
| 	// OnStart is called inline before ServeHTTP is called. Optional.
 | |
| 	OnStart OnStartFunc
 | |
| 
 | |
| 	// OnError is called if the handler returned a HTTPError. This
 | |
| 	// is intended to be used to present pretty error pages if
 | |
| 	// the user agent is determined to be a browser.
 | |
| 	OnError ErrorHandlerFunc
 | |
| 
 | |
| 	// OnCompletion is called inline when ServeHTTP is finished and gets
 | |
| 	// useful data that the implementor can use for metrics. Optional.
 | |
| 	OnCompletion OnCompletionFunc
 | |
| }
 | |
| 
 | |
| // LogOptions are the options used by [LogHandler].
 | |
| // These options are a subset of [HandlerOptions].
 | |
| type LogOptions struct {
 | |
| 	// Logf is used to log HTTP requests and responses.
 | |
| 	Logf logger.Logf
 | |
| 	// Now is a function giving the current time. Defaults to [time.Now].
 | |
| 	Now func() time.Time
 | |
| 
 | |
| 	// QuietLogging suppresses all logging of handled HTTP requests, even if
 | |
| 	// there are errors or status codes considered unsuccessful. Use this option
 | |
| 	// to add your own logging in OnCompletion.
 | |
| 	QuietLogging bool
 | |
| 	// QuietLoggingIfSuccessful suppresses logging of handled HTTP requests
 | |
| 	// where the request's response status code is 200 or 304.
 | |
| 	QuietLoggingIfSuccessful bool
 | |
| 
 | |
| 	// StatusCodeCounters maintains counters of status code classes.
 | |
| 	// The keys are "1xx", "2xx", "3xx", "4xx", and "5xx".
 | |
| 	// If nil, no counting is done.
 | |
| 	StatusCodeCounters *expvar.Map
 | |
| 	// StatusCodeCountersFull maintains counters of status codes.
 | |
| 	// The keys are HTTP numeric response codes e.g. 200, 404, ...
 | |
| 	// If nil, no counting is done.
 | |
| 	StatusCodeCountersFull *expvar.Map
 | |
| 	// BucketedStats computes and exposes statistics for each bucket based on
 | |
| 	// the contained parameters. If nil, no counting is done.
 | |
| 	BucketedStats *BucketedStatsOptions
 | |
| 
 | |
| 	// OnStart is called inline before ServeHTTP is called. Optional.
 | |
| 	OnStart OnStartFunc
 | |
| 	// OnCompletion is called inline when ServeHTTP is finished and gets
 | |
| 	// useful data that the implementor can use for metrics. Optional.
 | |
| 	OnCompletion OnCompletionFunc
 | |
| }
 | |
| 
 | |
| func (o HandlerOptions) logOptions() LogOptions {
 | |
| 	return LogOptions{
 | |
| 		QuietLoggingIfSuccessful: o.QuietLoggingIfSuccessful,
 | |
| 		Logf:                     o.Logf,
 | |
| 		Now:                      o.Now,
 | |
| 		StatusCodeCounters:       o.StatusCodeCounters,
 | |
| 		StatusCodeCountersFull:   o.StatusCodeCountersFull,
 | |
| 		BucketedStats:            o.BucketedStats,
 | |
| 		OnStart:                  o.OnStart,
 | |
| 		OnCompletion:             o.OnCompletion,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (opts LogOptions) withDefaults() LogOptions {
 | |
| 	if opts.Logf == nil {
 | |
| 		opts.Logf = logger.Discard
 | |
| 	}
 | |
| 	if opts.Now == nil {
 | |
| 		opts.Now = time.Now
 | |
| 	}
 | |
| 	return opts
 | |
| }
 | |
| 
 | |
| // ErrorOptions are options used by [ErrorHandler].
 | |
| type ErrorOptions struct {
 | |
| 	// Logf is used to record unexpected behaviours when returning HTTPError but
 | |
| 	// different error codes have already been written to the client.
 | |
| 	Logf logger.Logf
 | |
| 	// OnError is called if the handler returned a HTTPError. This
 | |
| 	// is intended to be used to present pretty error pages if
 | |
| 	// the user agent is determined to be a browser.
 | |
| 	OnError ErrorHandlerFunc
 | |
| }
 | |
| 
 | |
| func (opts ErrorOptions) withDefaults() ErrorOptions {
 | |
| 	if opts.Logf == nil {
 | |
| 		opts.Logf = logger.Discard
 | |
| 	}
 | |
| 	if opts.OnError == nil {
 | |
| 		opts.OnError = WriteHTTPError
 | |
| 	}
 | |
| 	return opts
 | |
| }
 | |
| 
 | |
| func (opts HandlerOptions) errorOptions() ErrorOptions {
 | |
| 	return ErrorOptions{
 | |
| 		OnError: opts.OnError,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // ErrorHandlerFunc is called to present an error response.
 | |
| type ErrorHandlerFunc func(http.ResponseWriter, *http.Request, HTTPError)
 | |
| 
 | |
| // OnStartFunc is called before ServeHTTP is called.
 | |
| type OnStartFunc func(*http.Request, AccessLogRecord)
 | |
| 
 | |
| // OnCompletionFunc is called when ServeHTTP is finished and gets
 | |
| // useful data that the implementor can use for metrics.
 | |
| type OnCompletionFunc func(*http.Request, AccessLogRecord)
 | |
| 
 | |
| // ReturnHandlerFunc is an adapter to allow the use of ordinary
 | |
| // functions as ReturnHandlers. If f is a function with the
 | |
| // appropriate signature, ReturnHandlerFunc(f) is a ReturnHandler that
 | |
| // calls f.
 | |
| type ReturnHandlerFunc func(http.ResponseWriter, *http.Request) error
 | |
| 
 | |
// A Middleware is a function that wraps an http.Handler to extend or modify
// its behaviour.
//
// The implementation of the wrapper is responsible for delegating its input
// request to the underlying handler, if appropriate.
type Middleware func(h http.Handler) http.Handler

// MiddlewareStack combines multiple middleware into a single middleware for
// decorating a [http.Handler]. The first middleware argument will be the first
// to process an incoming request, before passing the request onto subsequent
// middleware and eventually the wrapped handler.
//
// For example:
//
//	MiddlewareStack(A, B)(h).ServeHTTP(w, r)
//
// calls in sequence:
//
//	   a.ServeHTTP(w, r)
//	-> b.ServeHTTP(w, r)
//	-> h.ServeHTTP(w, r)
//
// (where the lowercase handlers were generated by the uppercase middleware).
//
// With no arguments the returned middleware hands back its handler
// unchanged; with exactly one it is that middleware itself.
func MiddlewareStack(mw ...Middleware) Middleware {
	if len(mw) == 1 {
		return mw[0]
	}
	return func(h http.Handler) http.Handler {
		// Wrap from the last middleware inwards so that mw[0] ends up
		// outermost and therefore sees the request first.
		wrapped := h
		for i := range mw {
			wrapped = mw[len(mw)-1-i](wrapped)
		}
		return wrapped
	}
}
 | |
| 
 | |
| // ServeHTTPReturn calls f(w, r).
 | |
| func (f ReturnHandlerFunc) ServeHTTPReturn(w http.ResponseWriter, r *http.Request) error {
 | |
| 	return f(w, r)
 | |
| }
 | |
| 
 | |
| // StdHandler converts a ReturnHandler into a standard http.Handler.
 | |
| // Handled requests are logged using opts.Logf, as are any errors.
 | |
| // Errors are handled as specified by the ReturnHandler interface.
 | |
| // Short-hand for LogHandler(ErrorHandler()).
 | |
| func StdHandler(h ReturnHandler, opts HandlerOptions) http.Handler {
 | |
| 	return LogHandler(ErrorHandler(h, opts.errorOptions()), opts.logOptions())
 | |
| }
 | |
| 
 | |
| // LogHandler returns an http.Handler that logs to opts.Logf.
 | |
| // It logs both successful and failing requests.
 | |
| // The log line includes the first error returned to [ErrorHandler] within.
 | |
| // The outer-most LogHandler(LogHandler(...)) does all of the logging.
 | |
| // Inner LogHandler instance do nothing.
 | |
| // Panics are swallowed and their stack traces are put in the error.
 | |
| func LogHandler(h http.Handler, opts LogOptions) http.Handler {
 | |
| 	return logHandler{h, opts.withDefaults()}
 | |
| }
 | |
| 
 | |
| // ErrorHandler converts a [ReturnHandler] into a standard [http.Handler].
 | |
| // Errors are handled as specified by the [ReturnHandler.ServeHTTPReturn] method.
 | |
| // When wrapped in a [LogHandler], panics are added to the [AccessLogRecord];
 | |
| // otherwise, panics continue up the stack.
 | |
| func ErrorHandler(h ReturnHandler, opts ErrorOptions) http.Handler {
 | |
| 	return errorHandler{h, opts.withDefaults()}
 | |
| }
 | |
| 
 | |
| // errCallback is added to logHandler's request context so that errorHandler can
 | |
| // pass errors back up the stack to logHandler.
 | |
| var errCallback = ctxkey.New[func(HTTPError)]("tailscale.com/tsweb.errCallback", nil)
 | |
| 
 | |
| // logHandler is a http.Handler which logs the HTTP request.
 | |
| // It injects an errCallback for errorHandler to augment the log message with
 | |
| // a specific error.
 | |
| type logHandler struct {
 | |
| 	h    http.Handler
 | |
| 	opts LogOptions
 | |
| }
 | |
| 
 | |
| func (h logHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 | |
| 	// If there's already a logHandler up the chain, skip this one.
 | |
| 	ctx := r.Context()
 | |
| 	if errCallback.Has(ctx) {
 | |
| 		h.h.ServeHTTP(w, r)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	msg := AccessLogRecord{
 | |
| 		Time:       h.opts.Now(),
 | |
| 		RemoteAddr: r.RemoteAddr,
 | |
| 		Proto:      r.Proto,
 | |
| 		TLS:        r.TLS != nil,
 | |
| 		Host:       r.Host,
 | |
| 		Method:     r.Method,
 | |
| 		RequestURI: r.URL.RequestURI(),
 | |
| 		UserAgent:  r.UserAgent(),
 | |
| 		Referer:    r.Referer(),
 | |
| 		RequestID:  RequestIDFromContext(r.Context()),
 | |
| 	}
 | |
| 
 | |
| 	if bs := h.opts.BucketedStats; bs != nil && bs.Started != nil && bs.Finished != nil {
 | |
| 		bucket := bs.bucketForRequest(r)
 | |
| 		var startRecorded bool
 | |
| 		switch v := bs.Started.Map.Get(bucket).(type) {
 | |
| 		case *expvar.Int:
 | |
| 			// If we've already seen this bucket for, count it immediately.
 | |
| 			// Otherwise, for newly seen paths, only count retroactively
 | |
| 			// (so started-finished doesn't go negative) so we don't fill
 | |
| 			// this LabelMap up with internet scanning spam.
 | |
| 			v.Add(1)
 | |
| 			startRecorded = true
 | |
| 		}
 | |
| 		defer func() {
 | |
| 			// Only increment metrics for buckets that result in good HTTP statuses
 | |
| 			// or when we know the start was already counted.
 | |
| 			// Otherwise they get full of internet scanning noise. Only filtering 404
 | |
| 			// gets most of the way there but there are also plenty of URLs that are
 | |
| 			// almost right but result in 400s too. Seem easier to just only ignore
 | |
| 			// all 4xx and 5xx.
 | |
| 			if startRecorded {
 | |
| 				bs.Finished.Add(bucket, 1)
 | |
| 			} else if msg.Code < 400 {
 | |
| 				// This is the first non-error request for this bucket,
 | |
| 				// so count it now retroactively.
 | |
| 				bs.Started.Add(bucket, 1)
 | |
| 				bs.Finished.Add(bucket, 1)
 | |
| 			}
 | |
| 		}()
 | |
| 	}
 | |
| 
 | |
| 	if fn := h.opts.OnStart; fn != nil {
 | |
| 		fn(r, msg)
 | |
| 	}
 | |
| 
 | |
| 	// Let errorHandler tell us what error it wrote to the client.
 | |
| 	r = r.WithContext(errCallback.WithValue(ctx, func(e HTTPError) {
 | |
| 		// Keep the deepest error.
 | |
| 		if msg.Err != "" {
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		// Log the error.
 | |
| 		if e.Msg != "" && e.Err != nil {
 | |
| 			msg.Err = e.Msg + ": " + e.Err.Error()
 | |
| 		} else if e.Err != nil {
 | |
| 			msg.Err = e.Err.Error()
 | |
| 		} else if e.Msg != "" {
 | |
| 			msg.Err = e.Msg
 | |
| 		}
 | |
| 
 | |
| 		// We log the code from the loggingResponseWriter, except for
 | |
| 		// cancellation where we override with 499.
 | |
| 		if reqCancelled(r, e.Err) {
 | |
| 			msg.Code = 499
 | |
| 		}
 | |
| 	}))
 | |
| 
 | |
| 	lw := newLogResponseWriter(h.opts.Logf, w, r)
 | |
| 
 | |
| 	defer func() {
 | |
| 		// If the handler panicked then make sure we include that in our error.
 | |
| 		// Panics caught up errorHandler shouldn't appear here, unless the panic
 | |
| 		// originates in one of its callbacks.
 | |
| 		recovered := recover()
 | |
| 		if recovered != nil {
 | |
| 			if msg.Err == "" {
 | |
| 				msg.Err = panic2err(recovered).Error()
 | |
| 			} else {
 | |
| 				msg.Err += "\n\nthen " + panic2err(recovered).Error()
 | |
| 			}
 | |
| 		}
 | |
| 		h.logRequest(r, lw, msg)
 | |
| 	}()
 | |
| 
 | |
| 	h.h.ServeHTTP(lw, r)
 | |
| }
 | |
| 
 | |
| func (h logHandler) logRequest(r *http.Request, lw *loggingResponseWriter, msg AccessLogRecord) {
 | |
| 	// Complete our access log from the loggingResponseWriter.
 | |
| 	msg.Bytes = lw.bytes
 | |
| 	msg.Seconds = h.opts.Now().Sub(msg.Time).Seconds()
 | |
| 	switch {
 | |
| 	case msg.Code != 0:
 | |
| 		// Keep explicit codes from a few particular errors.
 | |
| 	case lw.hijacked:
 | |
| 		// Connection no longer belongs to us, just log that we
 | |
| 		// switched protocols away from HTTP.
 | |
| 		msg.Code = http.StatusSwitchingProtocols
 | |
| 	case lw.code == 0:
 | |
| 		// If the handler didn't write and didn't send a header, that still means 200.
 | |
| 		// (See https://play.golang.org/p/4P7nx_Tap7p)
 | |
| 		msg.Code = 200
 | |
| 	default:
 | |
| 		msg.Code = lw.code
 | |
| 	}
 | |
| 
 | |
| 	// Keep track of the original response code when we've overridden it.
 | |
| 	if lw.code != 0 && msg.Code != lw.code {
 | |
| 		if msg.Err == "" {
 | |
| 			msg.Err = fmt.Sprintf("(original code %d)", lw.code)
 | |
| 		} else {
 | |
| 			msg.Err = fmt.Sprintf("%s (original code %d)", msg.Err, lw.code)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if !h.opts.QuietLogging && !(h.opts.QuietLoggingIfSuccessful && (msg.Code == http.StatusOK || msg.Code == http.StatusNotModified)) {
 | |
| 		h.opts.Logf("%s", msg)
 | |
| 	}
 | |
| 
 | |
| 	if h.opts.OnCompletion != nil {
 | |
| 		h.opts.OnCompletion(r, msg)
 | |
| 	}
 | |
| 
 | |
| 	// Closing metrics.
 | |
| 	if h.opts.StatusCodeCounters != nil {
 | |
| 		h.opts.StatusCodeCounters.Add(responseCodeString(msg.Code/100), 1)
 | |
| 	}
 | |
| 	if h.opts.StatusCodeCountersFull != nil {
 | |
| 		h.opts.StatusCodeCountersFull.Add(responseCodeString(msg.Code), 1)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func responseCodeString(code int) string {
 | |
| 	if v, ok := responseCodeCache.Load(code); ok {
 | |
| 		return v.(string)
 | |
| 	}
 | |
| 
 | |
| 	var ret string
 | |
| 	if code < 10 {
 | |
| 		ret = fmt.Sprintf("%dxx", code)
 | |
| 	} else {
 | |
| 		ret = strconv.Itoa(code)
 | |
| 	}
 | |
| 	responseCodeCache.Store(code, ret)
 | |
| 	return ret
 | |
| }
 | |
| 
 | |
| // responseCodeCache memoizes the string form of HTTP response codes,
 | |
| // so that the hot request-handling codepath doesn't have to allocate
 | |
| // in strconv/fmt for every request.
 | |
| //
 | |
| // Keys are either full HTTP response code ints (200, 404) or "family"
 | |
| // ints representing entire families (e.g. 2 for 2xx codes). Values
 | |
| // are the string form of that code/family.
 | |
| var responseCodeCache sync.Map
 | |
| 
 | |
| // loggingResponseWriter wraps a ResponseWriter and records the HTTP
 | |
| // response code that gets sent, if any.
 | |
| type loggingResponseWriter struct {
 | |
| 	http.ResponseWriter
 | |
| 	ctx      context.Context
 | |
| 	code     int
 | |
| 	bytes    int
 | |
| 	hijacked bool
 | |
| 	logf     logger.Logf
 | |
| }
 | |
| 
 | |
| // newLogResponseWriter returns a loggingResponseWriter which uses's the logger
 | |
| // from r, or falls back to logf. If a nil logger is given, the logs are
 | |
| // discarded.
 | |
| func newLogResponseWriter(logf logger.Logf, w http.ResponseWriter, r *http.Request) *loggingResponseWriter {
 | |
| 	if l, ok := logger.LogfKey.ValueOk(r.Context()); ok && l != nil {
 | |
| 		logf = l
 | |
| 	}
 | |
| 	if logf == nil {
 | |
| 		logf = logger.Discard
 | |
| 	}
 | |
| 	return &loggingResponseWriter{
 | |
| 		ResponseWriter: w,
 | |
| 		ctx:            r.Context(),
 | |
| 		logf:           logf,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // WriteHeader implements [http.ResponseWriter].
 | |
| func (l *loggingResponseWriter) WriteHeader(statusCode int) {
 | |
| 	if l.code != 0 {
 | |
| 		l.logf("[unexpected] HTTP handler set statusCode twice (%d and %d)", l.code, statusCode)
 | |
| 		return
 | |
| 	}
 | |
| 	if l.ctx.Err() == nil {
 | |
| 		l.code = statusCode
 | |
| 	}
 | |
| 	l.ResponseWriter.WriteHeader(statusCode)
 | |
| }
 | |
| 
 | |
| // Write implements [http.ResponseWriter].
 | |
| func (l *loggingResponseWriter) Write(bs []byte) (int, error) {
 | |
| 	if l.code == 0 {
 | |
| 		l.code = 200
 | |
| 	}
 | |
| 	n, err := l.ResponseWriter.Write(bs)
 | |
| 	l.bytes += n
 | |
| 	return n, err
 | |
| }
 | |
| 
 | |
| // Hijack implements http.Hijacker. Note that hijacking can still fail
 | |
| // because the wrapped ResponseWriter is not required to implement
 | |
| // Hijacker, as this breaks HTTP/2.
 | |
| func (l *loggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
 | |
| 	h, ok := l.ResponseWriter.(http.Hijacker)
 | |
| 	if !ok {
 | |
| 		return nil, nil, errors.New("ResponseWriter is not a Hijacker")
 | |
| 	}
 | |
| 	conn, buf, err := h.Hijack()
 | |
| 	if err == nil {
 | |
| 		l.hijacked = true
 | |
| 	}
 | |
| 	return conn, buf, err
 | |
| }
 | |
| 
 | |
| func (l loggingResponseWriter) Flush() {
 | |
| 	f, _ := l.ResponseWriter.(http.Flusher)
 | |
| 	if f == nil {
 | |
| 		l.logf("[unexpected] tried to Flush a ResponseWriter that can't flush")
 | |
| 		return
 | |
| 	}
 | |
| 	f.Flush()
 | |
| }
 | |
| 
 | |
| // errorHandler is an http.Handler that wraps a ReturnHandler to render the
 | |
| // returned errors to the client and pass them back to any logHandlers.
 | |
| type errorHandler struct {
 | |
| 	rh   ReturnHandler
 | |
| 	opts ErrorOptions
 | |
| }
 | |
| 
 | |
| // ServeHTTP implements the http.Handler interface.
 | |
| func (h errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 | |
| 	// Keep track of whether a response gets written.
 | |
| 	lw, ok := w.(*loggingResponseWriter)
 | |
| 	if !ok {
 | |
| 		lw = newLogResponseWriter(h.opts.Logf, w, r)
 | |
| 	}
 | |
| 
 | |
| 	var err error
 | |
| 	defer func() {
 | |
| 		// In case the handler panics, we want to recover and continue logging
 | |
| 		// the error before logging it (or re-panicking if we couldn't log).
 | |
| 		rec := recover()
 | |
| 		if rec != nil {
 | |
| 			err = panic2err(rec)
 | |
| 		}
 | |
| 		if err == nil {
 | |
| 			return
 | |
| 		}
 | |
| 		if h.handleError(w, r, lw, err) {
 | |
| 			return
 | |
| 		}
 | |
| 		if rec != nil {
 | |
| 			// If we weren't able to log the panic somewhere, throw it up the
 | |
| 			// stack to someone who can.
 | |
| 			panic(rec)
 | |
| 		}
 | |
| 	}()
 | |
| 	err = h.rh.ServeHTTPReturn(lw, r)
 | |
| }
 | |
| 
 | |
| func (h errorHandler) handleError(w http.ResponseWriter, r *http.Request, lw *loggingResponseWriter, err error) bool {
 | |
| 	var logged bool
 | |
| 
 | |
| 	// Extract a presentable, loggable error.
 | |
| 	var hOK bool
 | |
| 	var hErr HTTPError
 | |
| 	if errors.As(err, &hErr) {
 | |
| 		hOK = true
 | |
| 		if hErr.Code == 0 {
 | |
| 			lw.logf("[unexpected] HTTPError %v did not contain an HTTP status code, sending internal server error", hErr)
 | |
| 			hErr.Code = http.StatusInternalServerError
 | |
| 		}
 | |
| 	} else if v, ok := vizerror.As(err); ok {
 | |
| 		hErr = Error(http.StatusInternalServerError, v.Error(), nil)
 | |
| 	} else if reqCancelled(r, err) {
 | |
| 		// 499 is the Nginx convention meaning "Client Closed Connection".
 | |
| 		if errors.Is(err, context.Canceled) || errors.Is(err, http.ErrAbortHandler) {
 | |
| 			hErr = Error(499, "", err)
 | |
| 		} else {
 | |
| 			hErr = Error(499, "", fmt.Errorf("%w: %w", context.Canceled, err))
 | |
| 		}
 | |
| 	} else {
 | |
| 		// Omit the friendly message so HTTP logs show the bare error that was
 | |
| 		// returned and we know it's not a HTTPError.
 | |
| 		hErr = Error(http.StatusInternalServerError, "", err)
 | |
| 	}
 | |
| 
 | |
| 	// Tell the logger what error we wrote back to the client.
 | |
| 	if pb := errCallback.Value(r.Context()); pb != nil {
 | |
| 		pb(hErr)
 | |
| 		logged = true
 | |
| 	}
 | |
| 
 | |
| 	if r.Context().Err() != nil {
 | |
| 		return logged
 | |
| 	}
 | |
| 
 | |
| 	if lw.code != 0 {
 | |
| 		if hOK && hErr.Code != lw.code {
 | |
| 			lw.logf("[unexpected] handler returned HTTPError %v, but already sent response with code %d", hErr, lw.code)
 | |
| 		}
 | |
| 		return logged
 | |
| 	}
 | |
| 
 | |
| 	// Set a default error message from the status code. Do this after we pass
 | |
| 	// the error back to the logger so that `return errors.New("oh")` logs as
 | |
| 	// `"err": "oh"`, not `"err": "Internal Server Error: oh"`.
 | |
| 	if hErr.Msg == "" {
 | |
| 		switch hErr.Code {
 | |
| 		case 499:
 | |
| 			hErr.Msg = "Client Closed Request"
 | |
| 		default:
 | |
| 			hErr.Msg = http.StatusText(hErr.Code)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// If OnError panics before a response is written, write a bare 500 back.
 | |
| 	// OnError panics are thrown further up the stack.
 | |
| 	defer func() {
 | |
| 		if lw.code == 0 {
 | |
| 			if rec := recover(); rec != nil {
 | |
| 				w.WriteHeader(http.StatusInternalServerError)
 | |
| 				panic(rec)
 | |
| 			}
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	h.opts.OnError(w, r, hErr)
 | |
| 	return logged
 | |
| }
 | |
| 
 | |
// panic2err converts a recovered value to an error containing the panic stack trace.
//
// A nil value yields a nil error, and http.ErrAbortHandler is returned as is.
func panic2err(recovered any) error {
	switch recovered {
	case nil:
		return nil
	case http.ErrAbortHandler:
		return http.ErrAbortHandler
	}

	// Even if r is an error, do not wrap it as an error here as
	// that would allow things like panic(vizerror.New("foo"))
	// which is really hard to define the behavior of.
	buf := make([]byte, 10000)
	buf = buf[:runtime.Stack(buf, false)]
	return &panicError{rec: recovered, stack: buf}
}

// panicError is an error that contains a panic.
type panicError struct {
	rec   any    // the value recovered from the panic
	stack []byte // stack trace captured by panic2err
}

// Error implements the error interface, rendering the recovered value
// followed by the captured stack trace.
func (e *panicError) Error() string {
	return fmt.Sprintf("panic: %v\n\n%s", e.rec, e.stack)
}

// Unwrap returns the recovered value if it is an error, and nil otherwise.
func (e *panicError) Unwrap() error {
	if inner, ok := e.rec.(error); ok {
		return inner
	}
	return nil
}
 | |
| 
 | |
// reqCancelled reports whether the request was cancelled: either err wraps
// http.ErrAbortHandler, or r.Context().Err() is context.Canceled.
func reqCancelled(r *http.Request, err error) bool {
	if errors.Is(err, http.ErrAbortHandler) {
		return true
	}
	return r.Context().Err() == context.Canceled
}
 | |
| 
 | |
| // WriteHTTPError is the default error response formatter.
 | |
| func WriteHTTPError(w http.ResponseWriter, r *http.Request, e HTTPError) {
 | |
| 	// Don't write a response if we've hit a cancellation/abort.
 | |
| 	if r.Context().Err() != nil || errors.Is(e.Err, http.ErrAbortHandler) {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	// Default headers set by http.Error.
 | |
| 	h := w.Header()
 | |
| 	h.Set("Content-Type", "text/plain; charset=utf-8")
 | |
| 	h.Set("X-Content-Type-Options", "nosniff")
 | |
| 
 | |
| 	// Custom headers from the error.
 | |
| 	for k, vs := range e.Header {
 | |
| 		h[k] = vs
 | |
| 	}
 | |
| 
 | |
| 	// Write the msg back to the user.
 | |
| 	w.WriteHeader(e.Code)
 | |
| 	fmt.Fprint(w, e.Msg)
 | |
| 
 | |
| 	// If it's a plaintext message, add line breaks and RequestID.
 | |
| 	if strings.HasPrefix(h.Get("Content-Type"), "text/plain") {
 | |
| 		io.WriteString(w, "\n")
 | |
| 		if id := RequestIDFromContext(r.Context()); id != "" {
 | |
| 			io.WriteString(w, id.String())
 | |
| 			io.WriteString(w, "\n")
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
// HTTPError is an error with embedded HTTP response information.
//
// It is the error type to be (optionally) used by Handler.ServeHTTPReturn.
type HTTPError struct {
	// Code is the HTTP response code to send to the client; 0 means 500.
	Code int
	// Msg is the response body to send to the client.
	Msg string
	// Err is the detailed error to log on the server.
	Err error
	// Header is an optional set of HTTP headers to set in the response.
	Header http.Header
}

// Error implements the error interface. The rendered text contains the
// code, the quoted message, and the underlying error.
func (e HTTPError) Error() string {
	return fmt.Sprintf("httperror{%d, %q, %v}", e.Code, e.Msg, e.Err)
}

// Unwrap returns the underlying detailed error, which may be nil.
func (e HTTPError) Unwrap() error {
	return e.Err
}
 | |
| 
 | |
| // Error returns an HTTPError containing the given information.
 | |
| func Error(code int, msg string, err error) HTTPError {
 | |
| 	return HTTPError{Code: code, Msg: msg, Err: err}
 | |
| }
 | |
| 
 | |
// VarzHandler writes expvar values as Prometheus metrics.
//
// It forwards the request unchanged to varz.Handler and exists only as
// a wrapper for existing callers.
//
// TODO: migrate all users to varz.Handler or promvarz.Handler and remove this.
func VarzHandler(w http.ResponseWriter, r *http.Request) {
	varz.Handler(w, r)
}
 | |
| 
 | |
// CleanRedirectURL ensures that urlStr is a valid redirect URL to the
// current server, or one of allowedHosts. Returns the cleaned URL or
// a validation error.
//
// An empty urlStr yields an empty, non-nil URL. Hosts are compared
// against allowedHosts case-insensitively and without the port.
// Redirects without a scheme, user info, or host are treated as
// redirects to the current server, except that paths beginning with
// "//" or "/\" are rejected because browsers resolve them to a
// different host.
func CleanRedirectURL(urlStr string, allowedHosts []string) (*url.URL, error) {
	if urlStr == "" {
		return &url.URL{}, nil
	}
	// In some places, we unfortunately query-escape the redirect URL
	// too many times, and end up needing to redirect to a URL that's
	// still escaped by one level. Try to unescape the input; if that
	// fails, validate the input as given.
	if unescaped, err := url.QueryUnescape(urlStr); err == nil {
		urlStr = unescaped
	}

	// Go's URL parser and browser URL parsers disagree on the meaning
	// of malformed HTTP URLs. Given the input https:/evil.com, Go
	// parses it as hostname="", path="/evil.com". Browsers parse it
	// as hostname="evil.com", path="". This means that, using
	// malformed URLs, an attacker could trick us into approving of a
	// "local" redirect that in fact sends people elsewhere.
	//
	// This very blunt check enforces that we'll only process
	// redirects that are definitely well-formed URLs.
	//
	// Note that the check for just / also allows URLs of the form
	// "//foo.com/bar", which are scheme-relative redirects. These
	// must be handled with care below when determining whether a
	// redirect is relative to the current host. Notably,
	// url.URL.IsAbs reports // URLs as relative, whereas we want to
	// treat them as absolute redirects and verify the target host.
	if !hasSafeRedirectPrefix(urlStr) {
		return nil, fmt.Errorf("invalid redirect URL %q", urlStr)
	}

	u, err := url.Parse(urlStr)
	if err != nil {
		return nil, fmt.Errorf("invalid redirect URL %q: %w", urlStr, err)
	}

	// Redirects to self are always allowed. A self redirect must
	// start with u.Path, all prior URL sections must be empty.
	isSelfRedirect := u.Scheme == "" && u.Opaque == "" && u.User == nil && u.Host == ""
	if isSelfRedirect {
		// url.Parse does not treat "///evil.com" as having an
		// authority: it yields Host="" and Path="///evil.com". Browsers
		// skip the extra slashes, and treat "\" like "/", so both
		// "///evil.com" and "/\evil.com" would send the user to
		// evil.com. Refuse any "local" path that starts that way.
		if strings.HasPrefix(u.Path, "//") || strings.HasPrefix(u.Path, `/\`) {
			return nil, fmt.Errorf("invalid redirect URL %q", urlStr)
		}
		return u, nil
	}

	target := u.Hostname()
	for _, allowed := range allowedHosts {
		if strings.EqualFold(allowed, target) {
			return u, nil
		}
	}

	return nil, fmt.Errorf("disallowed target host %q in redirect URL %q", target, urlStr)
}

// hasSafeRedirectPrefix reports whether s starts with a slash, or
// one of the case-insensitive strings "http://" or "https://".
func hasSafeRedirectPrefix(s string) bool {
	if strings.HasPrefix(s, "/") {
		return true
	}
	for _, scheme := range [...]string{"http://", "https://"} {
		if len(s) >= len(scheme) && strings.EqualFold(s[:len(scheme)], scheme) {
			return true
		}
	}
	return false
}
 | |
| 
 | |
// AddBrowserHeaders sets various HTTP security headers for browser-facing endpoints.
//
// The specific headers:
//   - require HTTPS access (HSTS)
//   - disallow iframe embedding
//   - mitigate MIME confusion attacks
//
// Any value already present for one of these headers is replaced.
//
// These headers are based on
// https://infosec.mozilla.org/guidelines/web_security
func AddBrowserHeaders(w http.ResponseWriter) {
	securityHeaders := [...][2]string{
		{"Strict-Transport-Security", "max-age=63072000; includeSubDomains"},
		{"Content-Security-Policy", "default-src 'self'; frame-ancestors 'none'; form-action 'self'; base-uri 'self'; block-all-mixed-content; object-src 'none'"},
		{"X-Frame-Options", "DENY"},
		{"X-Content-Type-Options", "nosniff"},
	}
	for _, kv := range securityHeaders {
		w.Header().Set(kv[0], kv[1])
	}
}
 | |
| 
 | |
| // BrowserHeaderHandler wraps the provided http.Handler with a call to
 | |
| // AddBrowserHeaders.
 | |
| func BrowserHeaderHandler(h http.Handler) http.Handler {
 | |
| 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | |
| 		AddBrowserHeaders(w)
 | |
| 		h.ServeHTTP(w, r)
 | |
| 	})
 | |
| }
 | |
| 
 | |
| // BrowserHeaderHandlerFunc wraps the provided http.HandlerFunc with a call to
 | |
| // AddBrowserHeaders.
 | |
| func BrowserHeaderHandlerFunc(h http.HandlerFunc) http.HandlerFunc {
 | |
| 	return func(w http.ResponseWriter, r *http.Request) {
 | |
| 		AddBrowserHeaders(w)
 | |
| 		h.ServeHTTP(w, r)
 | |
| 	}
 | |
| }
 |