omni/internal/backend/saml/saml.go
Utku Ozdemir 0e76483bab
Some checks failed
default / default (push) Has been cancelled
default / e2e-backups (push) Has been cancelled
default / e2e-forced-removal (push) Has been cancelled
default / e2e-omni-upgrade (push) Has been cancelled
default / e2e-scaling (push) Has been cancelled
default / e2e-short (push) Has been cancelled
default / e2e-short-secureboot (push) Has been cancelled
default / e2e-templates (push) Has been cancelled
default / e2e-upgrades (push) Has been cancelled
default / e2e-workload-proxy (push) Has been cancelled
chore: rekres, bump deps, Go, Talos and k8s versions, satisfy linters
- Bump some deps, namely cosi-runtime and Talos machinery.
- Update `auditState` to implement the new methods in COSI's `state.State`.
- Bump default Talos and Kubernetes versions to their latest.
- Rekres, which brings Go 1.24.5. Also update it in go.mod files.
- Fix linter errors coming from new linters.

Signed-off-by: Utku Ozdemir <utku.ozdemir@siderolabs.com>
2025-07-11 18:23:48 +02:00

135 lines
3.5 KiB
Go

// Copyright (c) 2025 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
// Package saml contains SAML setup handlers.
package saml
import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"time"

	"github.com/cosi-project/runtime/pkg/state"
	"github.com/crewjam/saml"
	"github.com/crewjam/saml/samlsp"
	"github.com/prometheus/client_golang/prometheus"
	"go.uber.org/zap"

	"github.com/siderolabs/omni/client/api/omni/specs"
	"github.com/siderolabs/omni/internal/backend/logging"
	"github.com/siderolabs/omni/internal/backend/monitoring"
	"github.com/siderolabs/omni/internal/pkg/config"
)
// NewHandler creates a new SAML service provider middleware.
//
// The identity provider (IdP) metadata is loaded from cfg via readMetadata, and the
// service provider root URL is the configured API service URL. IdP-initiated logins
// are allowed, and both SAML responses and logout use the HTTP-POST binding. When
// cfg.NameIdFormat is set, it overrides the service provider's AuthnNameIDFormat.
//
// In-flight authentication requests are tracked with samlsp's default (cookie-based)
// request tracker, whose payload codec is Encoder; sessions are handled by
// NewSessionProvider on top of the given state.
//
// An error is returned if cfg is nil, the IdP metadata cannot be loaded, or the API
// service URL cannot be parsed.
func NewHandler(state state.State, cfg *specs.AuthConfigSpec_SAML, logger *zap.Logger) (*samlsp.Middleware, error) {
	if cfg == nil {
		return nil, errors.New("SAML auth config is nil")
	}

	idpMetadata, err := readMetadata(cfg)
	if err != nil {
		return nil, fmt.Errorf("loading SAML IdP metadata: %w", err)
	}

	rootURL, err := url.Parse(config.Config.Services.API.URL())
	if err != nil {
		return nil, fmt.Errorf("parsing API service URL: %w", err)
	}

	opts := samlsp.Options{
		URL:               *rootURL,
		IDPMetadata:       idpMetadata,
		LogoutBindings:    []string{saml.HTTPPostBinding},
		AllowIDPInitiated: true,
	}

	serviceProvider := samlsp.DefaultServiceProvider(opts)
	if cfg.NameIdFormat != "" {
		serviceProvider.AuthnNameIDFormat = saml.NameIDFormat(cfg.NameIdFormat)
	}

	m := &samlsp.Middleware{
		ServiceProvider:  serviceProvider,
		ResponseBinding:  saml.HTTPPostBinding,
		OnError:          createErrorHandler(logger),
		AssertionHandler: samlsp.DefaultAssertionHandler(opts),
	}

	// The request tracker keeps a *pointer* to the service provider. Point it at the
	// copy owned by the middleware (as samlsp.New does), not at the local variable
	// above, so the tracker and the middleware never diverge if m.ServiceProvider is
	// modified later.
	requestTracker := samlsp.DefaultRequestTracker(opts, &m.ServiceProvider)
	requestTracker.Codec = &Encoder{}

	m.RequestTracker = requestTracker
	m.Session = NewSessionProvider(
		state,
		requestTracker,
		logger.With(logging.Component("saml_session")),
	)

	return m, nil
}
// RegisterHandlers mounts the SAML-related HTTP endpoints on mux:
//
//   - "/saml/" is served by the SAML middleware itself;
//   - "/saml/metadata" serves the service provider metadata (Middleware.ServeMetadata);
//   - "/login" starts the authentication flow (Middleware.HandleStartAuthFlow).
//
// Each endpoint is wrapped in request logging (logger tagged with handler=saml) and
// monitoring instrumentation labelled handler="saml".
func RegisterHandlers(saml *samlsp.Middleware, mux *http.ServeMux, logger *zap.Logger) {
	handlerLogger := logger.With(zap.String("handler", "saml"))
	labels := prometheus.Labels{"handler": "saml"}

	routes := []struct {
		pattern string
		handler http.Handler
	}{
		{pattern: "/saml/", handler: saml},
		{pattern: "/saml/metadata", handler: http.HandlerFunc(saml.ServeMetadata)},
		{pattern: "/login", handler: http.HandlerFunc(saml.HandleStartAuthFlow)},
	}

	for _, route := range routes {
		instrumented := monitoring.NewHandler(
			logging.NewHandler(route.handler, handlerLogger),
			labels,
		)

		mux.Handle(route.pattern, instrumented)
	}
}
// readMetadata loads the SAML identity provider (IdP) metadata described by cfg.
//
// When cfg.Url is set, the metadata is fetched over HTTP from that URL. The whole
// fetch, including reading the response body, is bounded by a timeout so that an
// unresponsive IdP cannot block the caller indefinitely. Otherwise cfg.Metadata is
// treated as the path of a local file holding the metadata document.
//
// An error is returned if neither source is configured, or if fetching, reading,
// or parsing the metadata fails.
func readMetadata(cfg *specs.AuthConfigSpec_SAML) (*saml.EntityDescriptor, error) {
	// Upper bound for a remote metadata fetch; http.DefaultClient has no timeout of its own.
	const metadataFetchTimeout = 30 * time.Second

	if cfg.Url != "" {
		idpMetadataURL, err := url.Parse(cfg.Url)
		if err != nil {
			return nil, fmt.Errorf("parsing IdP metadata URL: %w", err)
		}

		ctx, cancel := context.WithTimeout(context.Background(), metadataFetchTimeout)
		defer cancel()

		metadata, err := samlsp.FetchMetadata(ctx, http.DefaultClient, *idpMetadataURL)
		if err != nil {
			return nil, fmt.Errorf("fetching IdP metadata from %q: %w", cfg.Url, err)
		}

		return metadata, nil
	}

	// Without this check os.ReadFile("") would fail with the unhelpful
	// "open : no such file or directory".
	if cfg.Metadata == "" {
		return nil, errors.New("neither an IdP metadata URL nor a metadata file path is configured")
	}

	data, err := os.ReadFile(cfg.Metadata)
	if err != nil {
		return nil, fmt.Errorf("reading IdP metadata file: %w", err)
	}

	metadata, err := samlsp.ParseMetadata(data)
	if err != nil {
		return nil, fmt.Errorf("parsing IdP metadata file %q: %w", cfg.Metadata, err)
	}

	return metadata, nil
}
// createErrorHandler returns a callback suitable for samlsp.Middleware.OnError.
//
// If err wraps a *saml.InvalidResponseError, the callback logs a warning with the
// raw response, the timestamp the response was evaluated at, and the private error
// detail. Any other error is logged at error level. In both cases the client is
// redirected to /forbidden with 303 See Other.
//
// NOTE(review): the raw SAML response may carry user identity attributes; confirm
// that writing it to the logs is acceptable.
func createErrorHandler(logger *zap.Logger) func(http.ResponseWriter, *http.Request, error) {
	componentLogger := logger.With(logging.Component("saml"))

	return func(w http.ResponseWriter, r *http.Request, err error) {
		var invalidResponse *saml.InvalidResponseError

		switch {
		case errors.As(err, &invalidResponse):
			componentLogger.Warn("received invalid saml response",
				zap.String("response", invalidResponse.Response),
				zap.Time("now", invalidResponse.Now),
				zap.Error(invalidResponse.PrivateErr),
			)
		default:
			componentLogger.Error("saml error", zap.Error(err))
		}

		http.Redirect(w, r, "/forbidden", http.StatusSeeOther)
	}
}