mirror of
https://github.com/siderolabs/omni.git
synced 2025-08-09 02:56:59 +02:00
Omni is source-available under BUSL. Signed-off-by: Andrey Smirnov <andrey.smirnov@siderolabs.com> Co-Authored-By: Artem Chernyshev <artem.chernyshev@talos-systems.com> Co-Authored-By: Utku Ozdemir <utku.ozdemir@siderolabs.com> Co-Authored-By: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com> Co-Authored-By: Philipp Sauter <philipp.sauter@siderolabs.com> Co-Authored-By: Noel Georgi <git@frezbo.dev> Co-Authored-By: evgeniybryzh <evgeniybryzh@gmail.com> Co-Authored-By: Tim Jones <tim.jones@siderolabs.com> Co-Authored-By: Andrew Rynhard <andrew@rynhard.io> Co-Authored-By: Spencer Smith <spencer.smith@talos-systems.com> Co-Authored-By: Christian Rolland <christian.rolland@siderolabs.com> Co-Authored-By: Gerard de Leeuw <gdeleeuw@leeuwit.nl> Co-Authored-By: Steve Francis <67986293+steverfrancis@users.noreply.github.com> Co-Authored-By: Volodymyr Mazurets <volodymyrmazureets@gmail.com>
221 lines
6.0 KiB
Go
221 lines
6.0 KiB
Go
// Copyright (c) 2024 Sidero Labs, Inc.
|
|
//
|
|
// Use of this software is governed by the Business Source License
|
|
// included in the LICENSE file.
|
|
|
|
package workloadproxy
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
|
|
"github.com/cosi-project/runtime/pkg/resource"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/siderolabs/omni/internal/pkg/auth"
|
|
"github.com/siderolabs/omni/internal/pkg/config"
|
|
)
|
|
|
|
// ProxyProvider is a provider of HTTP proxies for the exposed services.
type ProxyProvider interface {
	// GetProxy returns the HTTP handler proxying to the service exposed under
	// alias, together with the ID of the cluster that service belongs to (the
	// ID is what HTTPHandler validates access against).
	//
	// HTTPHandler treats a non-nil error as an internal error (500) and a nil
	// handler with a nil error as "no such service" (404).
	GetProxy(alias string) (http.Handler, resource.ID, error)
}
|
|
|
|
// AccessValidator validates workload proxy requests against the given cluster by the given public key ID and its signed & base64'd form.
type AccessValidator interface {
	// ValidateAccess returns nil when the holder of publicKeyID (proven by
	// publicKeyIDSignatureBase64) may access clusterID. HTTPHandler treats any
	// non-nil error as "access denied" and redirects the client to the
	// forbidden page.
	ValidateAccess(ctx context.Context, publicKeyID, publicKeyIDSignatureBase64 string, clusterID resource.ID) error
}
|
|
|
|
// HTTPHandler is an HTTP handler that will proxy matching requests to the workload proxy.
//
// It will pass through the requests that don't match.
type HTTPHandler struct {
	// next receives every request whose host is not a workload proxy host.
	next http.Handler
	// logger is never nil once built by NewHTTPHandler (a no-op logger is substituted).
	logger *zap.Logger
	// proxyProvider resolves a service alias to its proxy handler and cluster ID.
	proxyProvider ProxyProvider
	// accessValidator checks the signature cookies against the target cluster.
	accessValidator AccessValidator
	// mainURL is the main Omni URL; the "/forbidden" redirect is built from it.
	mainURL *url.URL
	// mainDomain is mainURL's host with any port stripped; workload proxy
	// hosts are recognized by ending in "-" + mainDomain.
	mainDomain string
}
|
|
|
|
// NewHTTPHandler creates a new HTTP handler that will proxy requests to the workload proxy.
|
|
func NewHTTPHandler(next http.Handler, proxyProvider ProxyProvider, accessValidator AccessValidator, mainURL *url.URL, logger *zap.Logger) (*HTTPHandler, error) {
|
|
if logger == nil {
|
|
logger = zap.NewNop()
|
|
}
|
|
|
|
if proxyProvider == nil {
|
|
return nil, errors.New("proxy provider is nil")
|
|
}
|
|
|
|
if accessValidator == nil {
|
|
return nil, errors.New("access validator is nil")
|
|
}
|
|
|
|
if mainURL == nil {
|
|
return nil, errors.New("main URL is nil")
|
|
}
|
|
|
|
return &HTTPHandler{
|
|
next: next,
|
|
proxyProvider: proxyProvider,
|
|
accessValidator: accessValidator,
|
|
mainURL: mainURL,
|
|
mainDomain: getMainDomain(mainURL),
|
|
logger: logger,
|
|
}, nil
|
|
}
|
|
|
|
// ServeHTTP implements http.Handler.
|
|
func (h *HTTPHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
|
if !h.isWorkloadProxyRequest(request) {
|
|
h.next.ServeHTTP(writer, request)
|
|
|
|
return
|
|
}
|
|
|
|
alias := h.parseServiceAliasFromHost(request)
|
|
if alias == "" {
|
|
http.NotFound(writer, request)
|
|
|
|
return
|
|
}
|
|
|
|
proxy, clusterID, err := h.proxyProvider.GetProxy(alias)
|
|
if err != nil {
|
|
h.logger.Warn("failed to get proxy", zap.Error(err), zap.String("alias", alias))
|
|
|
|
http.Error(writer, "failed to get proxy", http.StatusInternalServerError)
|
|
|
|
return
|
|
}
|
|
|
|
if proxy == nil {
|
|
h.logger.Debug("proxy is nil", zap.String("alias", alias))
|
|
|
|
http.NotFound(writer, request)
|
|
|
|
return
|
|
}
|
|
|
|
h.checkCookies(writer, request, proxy, clusterID)
|
|
}
|
|
|
|
func (h *HTTPHandler) isWorkloadProxyRequest(request *http.Request) bool {
|
|
host, _, _ := net.SplitHostPort(request.Host) //nolint:errcheck
|
|
|
|
if host == "" {
|
|
host = request.Host
|
|
}
|
|
|
|
return strings.HasPrefix(host, HostPrefix+"-") && strings.HasSuffix(host, "-"+h.mainDomain)
|
|
}
|
|
|
|
func (h *HTTPHandler) checkCookies(writer http.ResponseWriter, request *http.Request, proxy http.Handler, clusterID resource.ID) {
|
|
publicKeyID, publicKeyIDSignatureBase64 := h.getSignatureCookies(request)
|
|
if publicKeyID == "" || publicKeyIDSignatureBase64 == "" {
|
|
h.redirectToLogin(writer, request)
|
|
|
|
return
|
|
}
|
|
|
|
if err := h.accessValidator.ValidateAccess(request.Context(), publicKeyID, publicKeyIDSignatureBase64, clusterID); err != nil {
|
|
h.logger.Warn("failed to validate access", zap.Error(err))
|
|
|
|
forbiddenURL := h.mainURL.JoinPath("/forbidden").String()
|
|
|
|
http.Redirect(writer, request, forbiddenURL, http.StatusSeeOther)
|
|
|
|
return
|
|
}
|
|
|
|
proxy.ServeHTTP(writer, request)
|
|
}
|
|
|
|
// parseServiceAliasFromHost parses the service alias from the request host.
|
|
//
|
|
// The host will have the pattern: p-<alias>-<instance-name>.<main domain>.
|
|
func (h *HTTPHandler) parseServiceAliasFromHost(request *http.Request) (alias string) {
|
|
hostParts := strings.SplitN(request.Host, ".", 2)
|
|
if len(hostParts) == 0 {
|
|
h.logger.Debug("empty proxy service host", zap.String("host", request.Host))
|
|
|
|
return ""
|
|
}
|
|
|
|
proxyServiceHostPrefixParts := strings.SplitN(hostParts[0], "-", 3)
|
|
if len(proxyServiceHostPrefixParts) < 3 {
|
|
h.logger.Debug("invalid proxy service host prefix: wrong number of parts", zap.String("host", request.Host), zap.Strings("parts", proxyServiceHostPrefixParts))
|
|
|
|
return ""
|
|
}
|
|
|
|
if proxyServiceHostPrefixParts[0] != HostPrefix {
|
|
h.logger.Debug("invalid proxy service host prefix: doesn't start with the prefix", zap.String("host", request.Host), zap.Strings("parts", proxyServiceHostPrefixParts))
|
|
|
|
return ""
|
|
}
|
|
|
|
return proxyServiceHostPrefixParts[1]
|
|
}
|
|
|
|
func (h *HTTPHandler) getSignatureCookies(request *http.Request) (publicKeyID string, publicKeyIDSignatureBase64 string) {
|
|
for _, cookie := range request.Cookies() {
|
|
switch cookie.Name {
|
|
case PublicKeyIDCookie:
|
|
publicKeyID = cookie.Value
|
|
case PublicKeyIDSignatureBase64Cookie:
|
|
publicKeyIDSignatureBase64 = cookie.Value
|
|
}
|
|
|
|
if publicKeyID != "" && publicKeyIDSignatureBase64 != "" {
|
|
break
|
|
}
|
|
}
|
|
|
|
return publicKeyID, publicKeyIDSignatureBase64
|
|
}
|
|
|
|
func (h *HTTPHandler) redirectToLogin(writer http.ResponseWriter, request *http.Request) {
|
|
loginURL, err := url.Parse(config.Config.APIURL)
|
|
if err != nil {
|
|
h.logger.Warn("failed to redirect to login", zap.Error(err))
|
|
|
|
http.Error(writer, "failed to redirect to login", http.StatusInternalServerError)
|
|
|
|
return
|
|
}
|
|
|
|
reqURL := *request.URL
|
|
reqURL.Scheme = "https"
|
|
reqURL.Host = request.Host
|
|
|
|
if reqURL.Port() == "" && loginURL.Port() != "" {
|
|
reqURL.Host = fmt.Sprintf("%s:%s", request.Host, loginURL.Port())
|
|
}
|
|
|
|
loginURL.Path = "/omni/authenticate"
|
|
q := loginURL.Query()
|
|
q.Set(auth.RedirectQueryParam, reqURL.String())
|
|
q.Set(auth.FlowQueryParam, auth.ProxyAuthFlow)
|
|
|
|
loginURL.RawQuery = q.Encode()
|
|
|
|
http.Redirect(writer, request, loginURL.String(), http.StatusSeeOther)
|
|
}
|
|
|
|
// getMainDomain returns the host of u with any port removed.
//
// When u.Host cannot be split into host and port (for example, because it has
// no port), or the host part is empty, u.Host is returned unchanged.
func getMainDomain(u *url.URL) string {
	if hostOnly, _, _ := net.SplitHostPort(u.Host); hostOnly != "" { //nolint:errcheck
		return hostOnly
	}

	return u.Host
}
|