vault/builtin/audit/socket/backend.go
Kuba Wieczorek 2047ce7527
[VAULT-22480] Add audit fallback device (#24583)
Co-authored-by: Peter Wilson <peter.wilson@hashicorp.com>
2024-01-08 13:57:43 +00:00

469 lines
12 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package socket
import (
"bytes"
"context"
"fmt"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/hashicorp/eventlogger"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/internal/observability/event"
"github.com/hashicorp/vault/sdk/helper/salt"
"github.com/hashicorp/vault/sdk/logical"
)
// Compile-time assertion that *Backend satisfies the audit.Backend interface.
var _ audit.Backend = (*Backend)(nil)
// Backend is the audit backend for the socket audit transport.
// It is created by Factory. The legacy logging methods write formatted entries
// directly to a network connection, while the node fields are populated only
// when Factory is asked to use the event logger.
type Backend struct {
// Embedded mutex, held by LogRequest, LogResponse, LogTestMessage (legacy
// path) and Reload while writing to or re-establishing the connection.
sync.Mutex
// address is the address dialled by reconnect (config key "address").
address string
// connection is the current socket connection; it is nil until the first
// successful reconnect and is reset to nil whenever reconnect starts.
connection net.Conn
// fallback records the parsed "fallback" config option and is reported by
// IsFallback.
fallback bool
// formatter formats and writes entries for the legacy LogRequest and
// LogResponse methods.
formatter *audit.EntryFormatterWriter
// formatConfig is the formatter configuration built by formatterConfig from
// the config map supplied to Factory.
formatConfig audit.FormatterConfig
// name is set from the mount path supplied to Factory and returned by Name.
name string
// nodeIDList holds the IDs of the configured nodes in the order they were
// added (filter if any, then formatter, then sink).
nodeIDList []eventlogger.NodeID
// nodeMap maps each ID in nodeIDList to its node.
nodeMap map[eventlogger.NodeID]eventlogger.Node
// salt is the cached salt; it is created on demand by Salt and cleared by
// Invalidate.
salt *salt.Salt
// saltConfig and saltView are the inputs passed to salt.NewSalt by Salt.
saltConfig *salt.Config
// saltMutex guards salt.
saltMutex sync.RWMutex
saltView logical.Storage
// socketType is the network passed to the dialer (config key "socket_type",
// defaulting to "tcp" in Factory).
socketType string
// writeDuration is used both as the write deadline offset in write and as
// the dial timeout in reconnect (config key "write_timeout").
writeDuration time.Duration
}
// Factory creates a socket audit Backend from the supplied configuration.
//
// Keys read from conf.Config:
//   - "address" (required): the address the socket connects to.
//   - "socket_type": the network type, defaults to "tcp".
//   - "write_timeout": duration string, defaults to "2s".
//   - "fallback": boolean; it cannot be combined with "filter".
//   - "filter": expression used to build a filter node when useEventLogger
//     is true.
//   - "prefix": prefix handed to the JSON/JSONx writer.
//   - "format", "hmac_accessor", "log_raw", "elide_list_responses": see
//     formatterConfig.
//
// When useEventLogger is true the filter (if any), formatter and sink nodes
// are configured on the returned Backend as well.
//
// An error is returned when conf, conf.SaltConfig or conf.SaltView is nil,
// when "address" is missing, when an option cannot be parsed, when both
// "fallback" and "filter" are set, or when any formatter or node cannot be
// created.
func Factory(_ context.Context, conf *audit.BackendConfig, useEventLogger bool, headersConfig audit.HeaderFormatter) (audit.Backend, error) {
	const op = "socket.Factory"

	// Guard against a nil config up front: every check below dereferences it,
	// so without this a nil conf would panic instead of returning an error.
	if conf == nil {
		return nil, fmt.Errorf("%s: nil config: %w", op, event.ErrInvalidParameter)
	}

	if conf.SaltConfig == nil {
		return nil, fmt.Errorf("%s: nil salt config", op)
	}

	if conf.SaltView == nil {
		return nil, fmt.Errorf("%s: nil salt view", op)
	}

	address, ok := conf.Config["address"]
	if !ok {
		return nil, fmt.Errorf("%s: address is required", op)
	}

	socketType, ok := conf.Config["socket_type"]
	if !ok {
		socketType = "tcp"
	}

	writeDeadline, ok := conf.Config["write_timeout"]
	if !ok {
		writeDeadline = "2s"
	}

	writeDuration, err := parseutil.ParseDurationSecond(writeDeadline)
	if err != nil {
		return nil, fmt.Errorf("%s: failed to parse 'write_timeout': %w", op, err)
	}

	// The config options 'fallback' and 'filter' are mutually exclusive, a fallback
	// device catches everything, so it cannot be allowed to filter.
	var fallback bool
	if fallbackRaw, ok := conf.Config["fallback"]; ok {
		fallback, err = parseutil.ParseBool(fallbackRaw)
		if err != nil {
			return nil, fmt.Errorf("%s: unable to parse 'fallback': %w", op, err)
		}
	}

	if _, ok := conf.Config["filter"]; ok && fallback {
		return nil, fmt.Errorf("%s: cannot configure a fallback device with a filter: %w", op, event.ErrInvalidParameter)
	}

	cfg, err := formatterConfig(conf.Config)
	if err != nil {
		return nil, fmt.Errorf("%s: failed to create formatter config: %w", op, err)
	}

	b := &Backend{
		fallback:      fallback,
		address:       address,
		formatConfig:  cfg,
		name:          conf.MountPath,
		saltConfig:    conf.SaltConfig,
		saltView:      conf.SaltView,
		socketType:    socketType,
		writeDuration: writeDuration,
	}

	// Configure the formatter for either case.
	f, err := audit.NewEntryFormatter(cfg, b, audit.WithHeaderFormatter(headersConfig))
	if err != nil {
		return nil, fmt.Errorf("%s: error creating formatter: %w", op, err)
	}

	// Pick the writer matching the required format; w stays nil for any other
	// format and is passed on as-is to NewEntryFormatterWriter.
	var w audit.Writer
	switch b.formatConfig.RequiredFormat {
	case audit.JSONFormat:
		w = &audit.JSONWriter{Prefix: conf.Config["prefix"]}
	case audit.JSONxFormat:
		w = &audit.JSONxWriter{Prefix: conf.Config["prefix"]}
	}

	fw, err := audit.NewEntryFormatterWriter(b.formatConfig, f, w)
	if err != nil {
		return nil, fmt.Errorf("%s: error creating formatter writer: %w", op, err)
	}

	b.formatter = fw

	if useEventLogger {
		b.nodeIDList = []eventlogger.NodeID{}
		b.nodeMap = make(map[eventlogger.NodeID]eventlogger.Node)

		// Nodes are appended in pipeline order: filter (optional), formatter, sink.
		err := b.configureFilterNode(conf.Config["filter"])
		if err != nil {
			return nil, fmt.Errorf("%s: error configuring filter node: %w", op, err)
		}

		opts := []audit.Option{
			audit.WithHeaderFormatter(headersConfig),
		}

		err = b.configureFormatterNode(cfg, opts...)
		if err != nil {
			return nil, fmt.Errorf("%s: error configuring formatter node: %w", op, err)
		}

		sinkOpts := []event.Option{
			event.WithSocketType(socketType),
			event.WithMaxDuration(writeDeadline),
		}

		err = b.configureSinkNode(conf.MountPath, address, cfg.RequiredFormat.String(), sinkOpts...)
		if err != nil {
			return nil, fmt.Errorf("%s: error configuring sink node: %w", op, err)
		}
	}

	return b, nil
}
// LogRequest formats the request described by in and writes it to the socket.
// If the first write fails, the connection is re-established once and the
// write is retried; if reconnecting fails, both errors are returned together.
//
// Deprecated: Use eventlogger.
func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error {
	entry := new(bytes.Buffer)
	if fmtErr := b.formatter.FormatAndWriteRequest(ctx, entry, in); fmtErr != nil {
		return fmtErr
	}
	payload := entry.Bytes()

	b.Lock()
	defer b.Unlock()

	writeErr := b.write(ctx, payload)
	if writeErr == nil {
		return nil
	}

	if dialErr := b.reconnect(ctx); dialErr != nil {
		return multierror.Append(writeErr, dialErr)
	}

	// Try once more after reconnecting.
	return b.write(ctx, payload)
}
// LogResponse formats the response described by in and writes it to the
// socket. If the first write fails, the connection is re-established once and
// the write is retried; if reconnecting fails, both errors are returned
// together.
//
// Deprecated: Use eventlogger.
func (b *Backend) LogResponse(ctx context.Context, in *logical.LogInput) error {
	entry := new(bytes.Buffer)
	if fmtErr := b.formatter.FormatAndWriteResponse(ctx, entry, in); fmtErr != nil {
		return fmtErr
	}
	payload := entry.Bytes()

	b.Lock()
	defer b.Unlock()

	writeErr := b.write(ctx, payload)
	if writeErr == nil {
		return nil
	}

	if dialErr := b.reconnect(ctx); dialErr != nil {
		return multierror.Append(writeErr, dialErr)
	}

	// Try once more after reconnecting.
	return b.write(ctx, payload)
}
// LogTestMessage sends a test entry through this backend.
//
// When event logger nodes have been configured, the input is processed
// manually through each of them. Otherwise a temporary formatter built from
// config["format"] and config["prefix"] formats the request, which is written
// to the socket with a single reconnect-and-retry on failure.
func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error {
	// Event logger behavior - manually Process each node.
	if len(b.nodeIDList) > 0 {
		return audit.ProcessManual(ctx, in, b.nodeIDList, b.nodeMap)
	}

	// Old behavior.
	tmpFormatter, err := audit.NewTemporaryFormatter(config["format"], config["prefix"])
	if err != nil {
		return err
	}

	entry := new(bytes.Buffer)
	if fmtErr := tmpFormatter.FormatAndWriteRequest(ctx, entry, in); fmtErr != nil {
		return fmtErr
	}
	payload := entry.Bytes()

	b.Lock()
	defer b.Unlock()

	writeErr := b.write(ctx, payload)
	if writeErr == nil {
		return nil
	}

	if dialErr := b.reconnect(ctx); dialErr != nil {
		return multierror.Append(writeErr, dialErr)
	}

	// Try once more after reconnecting.
	return b.write(ctx, payload)
}
// write sends buf over the current connection, dialling first when no
// connection exists yet. A write deadline of now plus writeDuration is applied
// before the data is sent. Every caller in this file holds the Backend's
// mutex while calling it.
//
// Deprecated: Use eventlogger.
func (b *Backend) write(ctx context.Context, buf []byte) error {
	if b.connection == nil {
		if dialErr := b.reconnect(ctx); dialErr != nil {
			return dialErr
		}
	}

	deadline := time.Now().Add(b.writeDuration)
	if dlErr := b.connection.SetWriteDeadline(deadline); dlErr != nil {
		return dlErr
	}

	_, err := b.connection.Write(buf)
	return err
}
// reconnect drops any existing connection and dials a new one to the
// configured address and socket type. The dial is bounded by writeDuration.
// On failure the Backend is left without a connection.
//
// Deprecated: Use eventlogger.
func (b *Backend) reconnect(ctx context.Context) error {
	if old := b.connection; old != nil {
		// The close error is ignored: the connection is discarded regardless.
		_ = old.Close()
		b.connection = nil
	}

	dialCtx, cancel := context.WithTimeout(ctx, b.writeDuration)
	defer cancel()

	var d net.Dialer
	fresh, err := d.DialContext(dialCtx, b.socketType, b.address)
	if err != nil {
		return err
	}

	b.connection = fresh
	return nil
}
// Reload re-establishes the socket connection while holding the Backend's
// mutex, returning any error from the reconnect attempt.
func (b *Backend) Reload(ctx context.Context) error {
	b.Lock()
	defer b.Unlock()

	return b.reconnect(ctx)
}
// Salt returns the salt for this backend, creating and caching it from the
// salt view and salt config on first use (or after Invalidate cleared it).
func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) {
	// Fast path: return the cached salt while holding only the read lock.
	b.saltMutex.RLock()
	cached := b.salt
	b.saltMutex.RUnlock()
	if cached != nil {
		return cached, nil
	}

	b.saltMutex.Lock()
	defer b.saltMutex.Unlock()

	// Re-check under the write lock: another caller may have created the salt
	// between the read lock being released and the write lock being acquired.
	if b.salt != nil {
		return b.salt, nil
	}

	created, err := salt.NewSalt(ctx, b.saltView, b.saltConfig)
	if err != nil {
		return nil, err
	}

	b.salt = created
	return created, nil
}
// Invalidate clears the cached salt so that the next call to Salt creates a
// fresh one.
func (b *Backend) Invalidate(_ context.Context) {
	b.saltMutex.Lock()
	b.salt = nil
	b.saltMutex.Unlock()
}
// formatterConfig builds the audit.FormatterConfig for a formatter node from
// the config map supplied to the factory. It reads the optional keys "format",
// "hmac_accessor", "log_raw" and "elide_list_responses"; the last three must
// parse as booleans, otherwise an error is returned.
func formatterConfig(config map[string]string) (audit.FormatterConfig, error) {
	const op = "socket.formatterConfig"

	// lookupBool reports whether key is present and, when it is, the result of
	// parsing its value as a boolean.
	lookupBool := func(key string) (bool, bool, error) {
		raw, found := config[key]
		if !found {
			return false, false, nil
		}
		parsed, err := strconv.ParseBool(raw)
		return parsed, true, err
	}

	var opts []audit.Option

	if format, found := config["format"]; found {
		opts = append(opts, audit.WithFormat(format))
	}

	// Check if hashing of accessor is disabled.
	hmacAccessor, found, err := lookupBool("hmac_accessor")
	switch {
	case err != nil:
		return audit.FormatterConfig{}, fmt.Errorf("%s: unable to parse 'hmac_accessor': %w", op, err)
	case found:
		opts = append(opts, audit.WithHMACAccessor(hmacAccessor))
	}

	// Check if raw logging is enabled.
	logRaw, found, err := lookupBool("log_raw")
	switch {
	case err != nil:
		return audit.FormatterConfig{}, fmt.Errorf("%s: unable to parse 'log_raw': %w", op, err)
	case found:
		opts = append(opts, audit.WithRaw(logRaw))
	}

	elide, found, err := lookupBool("elide_list_responses")
	switch {
	case err != nil:
		return audit.FormatterConfig{}, fmt.Errorf("%s: unable to parse 'elide_list_responses': %w", op, err)
	case found:
		opts = append(opts, audit.WithElision(elide))
	}

	return audit.NewFormatterConfig(opts...)
}
// configureFilterNode registers a filter node, under a newly generated ID, on
// the Backend. A filter that is empty after trimming whitespace is a no-op.
func (b *Backend) configureFilterNode(filter string) error {
	const op = "socket.(Backend).configureFilterNode"

	expr := strings.TrimSpace(filter)
	if expr == "" {
		// Nothing to filter on, so no node is added.
		return nil
	}

	id, err := event.GenerateNodeID()
	if err != nil {
		return fmt.Errorf("%s: error generating random NodeID for filter node: %w", op, err)
	}

	node, err := audit.NewEntryFilter(expr)
	if err != nil {
		return fmt.Errorf("%s: error creating filter node: %w", op, err)
	}

	b.nodeIDList = append(b.nodeIDList, id)
	b.nodeMap[id] = node

	return nil
}
// configureFormatterNode registers a formatter node, under a newly generated
// ID, on the Backend. The node is built from formatConfig and opts, with the
// Backend itself passed to the entry formatter.
func (b *Backend) configureFormatterNode(formatConfig audit.FormatterConfig, opts ...audit.Option) error {
	const op = "socket.(Backend).configureFormatterNode"

	id, err := event.GenerateNodeID()
	if err != nil {
		return fmt.Errorf("%s: error generating random NodeID for formatter node: %w", op, err)
	}

	node, err := audit.NewEntryFormatter(formatConfig, b, opts...)
	if err != nil {
		return fmt.Errorf("%s: error creating formatter: %w", op, err)
	}

	b.nodeIDList = append(b.nodeIDList, id)
	b.nodeMap[id] = node

	return nil
}
// configureSinkNode registers a socket sink node, under a newly generated ID,
// on the Backend. name, address and format are trimmed of surrounding
// whitespace and must all be non-empty; the sink is wrapped in an
// audit.SinkWrapper carrying name.
func (b *Backend) configureSinkNode(name string, address string, format string, opts ...event.Option) error {
	const op = "socket.(Backend).configureSinkNode"

	name = strings.TrimSpace(name)
	address = strings.TrimSpace(address)
	format = strings.TrimSpace(format)

	// Validate in the order name, address, format so the first missing value
	// determines the error.
	switch {
	case name == "":
		return fmt.Errorf("%s: name is required: %w", op, event.ErrInvalidParameter)
	case address == "":
		return fmt.Errorf("%s: address is required: %w", op, event.ErrInvalidParameter)
	case format == "":
		return fmt.Errorf("%s: format is required: %w", op, event.ErrInvalidParameter)
	}

	id, err := event.GenerateNodeID()
	if err != nil {
		return fmt.Errorf("%s: error generating random NodeID for sink node: %w", op, err)
	}

	sink, err := event.NewSocketSink(address, format, opts...)
	if err != nil {
		return fmt.Errorf("%s: error creating socket sink node: %w", op, err)
	}

	b.nodeIDList = append(b.nodeIDList, id)
	b.nodeMap[id] = &audit.SinkWrapper{Name: name, Sink: sink}

	return nil
}
// Name returns the name given to this backend, which Factory sets from the
// mount path of the audit device.
func (b *Backend) Name() string {
	return b.name
}
// Nodes returns the backend's nodes, keyed by node ID, for the event framework
// to use when processing audit entries. The backend's own map is returned,
// not a copy.
func (b *Backend) Nodes() map[eventlogger.NodeID]eventlogger.Node {
	return b.nodeMap
}
// NodeIDs returns the node IDs in the order the nodes must be run. The
// backend's own slice is returned, not a copy.
func (b *Backend) NodeIDs() []eventlogger.NodeID {
	return b.nodeIDList
}
// EventType returns the event type handled by this backend, which is always
// the audit event type.
func (b *Backend) EventType() eventlogger.EventType {
	auditType := event.AuditType.String()

	return eventlogger.EventType(auditType)
}
// HasFiltering reports whether the first node in the backend's pipeline is an
// eventlogger.NodeTypeFilter. A backend with no nodes has no filtering.
func (b *Backend) HasFiltering() bool {
	if len(b.nodeIDList) == 0 {
		return false
	}

	first := b.nodeMap[b.nodeIDList[0]]

	return first.Type() == eventlogger.NodeTypeFilter
}
// IsFallback reports whether this audit backend was configured as a fallback
// device, i.e. one intended to catch all events that are not written when only
// filtered pipelines are in use.
func (b *Backend) IsFallback() bool {
	return b.fallback
}