Merge branch 'opensource-master' into struct-tags

This commit is contained in:
Becca Petrin 2018-04-11 13:04:08 -07:00
commit c588d02282
10 changed files with 201 additions and 121 deletions

View File

@ -5,18 +5,71 @@ import (
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
"strconv"
"testing" "testing"
"time" "time"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing" logicaltest "github.com/hashicorp/vault/logical/testing"
dockertest "gopkg.in/ory-am/dockertest.v3"
) )
const ( const (
testSysTTL = time.Hour * 10 testSysTTL = time.Hour * 10
testSysMaxTTL = time.Hour * 20 testSysMaxTTL = time.Hour * 20
envRadiusRadiusHost = "RADIUS_HOST"
envRadiusPort = "RADIUS_PORT"
envRadiusSecret = "RADIUS_SECRET"
envRadiusUsername = "RADIUS_USERNAME"
envRadiusUserPass = "RADIUS_USERPASS"
) )
func prepareRadiusTestContainer(t *testing.T) (func(), string, int) {
if os.Getenv(envRadiusRadiusHost) != "" {
port, _ := strconv.Atoi(os.Getenv(envRadiusPort))
return func() {}, os.Getenv(envRadiusRadiusHost), port
}
pool, err := dockertest.NewPool("")
if err != nil {
t.Fatalf("Failed to connect to docker: %s", err)
}
runOpts := &dockertest.RunOptions{
Repository: "jumanjiman/radiusd",
Cmd: []string{"-f", "-l", "stdout"},
ExposedPorts: []string{"1812/udp"},
Tag: "latest",
}
resource, err := pool.RunWithOptions(runOpts)
if err != nil {
t.Fatalf("Could not start local radius docker container: %s", err)
}
cleanup := func() {
err := pool.Purge(resource)
if err != nil {
t.Fatalf("Failed to cleanup local container: %s", err)
}
}
port, _ := strconv.Atoi(resource.GetPort("1812/udp"))
address := fmt.Sprintf("127.0.0.1")
// exponential backoff-retry
if err = pool.Retry(func() error {
// There's no straightfoward way to check the state, but the server starts
// up quick so a 2 second sleep should be enough.
time.Sleep(2 * time.Second)
return nil
}); err != nil {
cleanup()
t.Fatalf("Could not connect to radius docker container: %s", err)
}
return cleanup, address, port
}
func TestBackend_Config(t *testing.T) { func TestBackend_Config(t *testing.T) {
b, err := Factory(context.Background(), &logical.BackendConfig{ b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil, Logger: nil,
@ -29,43 +82,43 @@ func TestBackend_Config(t *testing.T) {
t.Fatalf("Unable to create backend: %s", err) t.Fatalf("Unable to create backend: %s", err)
} }
config_data_basic := map[string]interface{}{ configDataBasic := map[string]interface{}{
"host": "test.radius.hostname.com", "host": "test.radius.hostname.com",
"secret": "test-secret", "secret": "test-secret",
} }
config_data_missingrequired := map[string]interface{}{ configDataMissingRequired := map[string]interface{}{
"host": "test.radius.hostname.com", "host": "test.radius.hostname.com",
} }
config_data_invalidport := map[string]interface{}{ configDataEmptyPort := map[string]interface{}{
"host": "test.radius.hostname.com",
"port": "",
"secret": "test-secret",
}
configDataInvalidPort := map[string]interface{}{
"host": "test.radius.hostname.com", "host": "test.radius.hostname.com",
"port": "notnumeric", "port": "notnumeric",
"secret": "test-secret", "secret": "test-secret",
} }
config_data_invalidbool := map[string]interface{}{ configDataInvalidBool := map[string]interface{}{
"host": "test.radius.hostname.com", "host": "test.radius.hostname.com",
"secret": "test-secret", "secret": "test-secret",
"unregistered_user_policies": "test", "unregistered_user_policies": "test",
} }
config_data_emptyport := map[string]interface{}{
"host": "test.radius.hostname.com",
"port": "",
"secret": "test-secret",
}
logicaltest.Test(t, logicaltest.TestCase{ logicaltest.Test(t, logicaltest.TestCase{
AcceptanceTest: false, AcceptanceTest: false,
// PreCheck: func() { testAccPreCheck(t) }, // PreCheck: func() { testAccPreCheck(t) },
Backend: b, Backend: b,
Steps: []logicaltest.TestStep{ Steps: []logicaltest.TestStep{
testConfigWrite(t, config_data_basic, false), testConfigWrite(t, configDataBasic, false),
testConfigWrite(t, config_data_emptyport, true), testConfigWrite(t, configDataMissingRequired, true),
testConfigWrite(t, config_data_invalidport, true), testConfigWrite(t, configDataEmptyPort, true),
testConfigWrite(t, config_data_invalidbool, true), testConfigWrite(t, configDataInvalidPort, true),
testConfigWrite(t, config_data_missingrequired, true), testConfigWrite(t, configDataInvalidBool, true),
}, },
}) })
} }
@ -93,7 +146,6 @@ func TestBackend_users(t *testing.T) {
} }
func TestBackend_acceptance(t *testing.T) { func TestBackend_acceptance(t *testing.T) {
if os.Getenv(logicaltest.TestEnvVar) == "" { if os.Getenv(logicaltest.TestEnvVar) == "" {
t.Skip(fmt.Sprintf("Acceptance tests skipped unless env '%s' set", logicaltest.TestEnvVar)) t.Skip(fmt.Sprintf("Acceptance tests skipped unless env '%s' set", logicaltest.TestEnvVar))
return return
@ -110,10 +162,29 @@ func TestBackend_acceptance(t *testing.T) {
t.Fatalf("Unable to create backend: %s", err) t.Fatalf("Unable to create backend: %s", err)
} }
cleanup, host, port := prepareRadiusTestContainer(t)
defer cleanup()
// These defaults are specific to the jumanjiman/radiusd docker image
username := os.Getenv(envRadiusUsername)
if username == "" {
username = "test"
}
password := os.Getenv(envRadiusUserPass)
if password == "" {
password = "test"
}
secret := os.Getenv(envRadiusSecret)
if len(secret) == 0 {
secret = "testing123"
}
configDataAcceptanceAllowUnreg := map[string]interface{}{ configDataAcceptanceAllowUnreg := map[string]interface{}{
"host": os.Getenv("RADIUS_HOST"), "host": host,
"port": os.Getenv("RADIUS_PORT"), "port": strconv.Itoa(port),
"secret": os.Getenv("RADIUS_SECRET"), "secret": secret,
"unregistered_user_policies": "policy1,policy2", "unregistered_user_policies": "policy1,policy2",
} }
if configDataAcceptanceAllowUnreg["port"] == "" { if configDataAcceptanceAllowUnreg["port"] == "" {
@ -121,9 +192,9 @@ func TestBackend_acceptance(t *testing.T) {
} }
configDataAcceptanceNoAllowUnreg := map[string]interface{}{ configDataAcceptanceNoAllowUnreg := map[string]interface{}{
"host": os.Getenv("RADIUS_HOST"), "host": host,
"port": os.Getenv("RADIUS_PORT"), "port": strconv.Itoa(port),
"secret": os.Getenv("RADIUS_SECRET"), "secret": secret,
"unregistered_user_policies": "", "unregistered_user_policies": "",
} }
if configDataAcceptanceNoAllowUnreg["port"] == "" { if configDataAcceptanceNoAllowUnreg["port"] == "" {
@ -131,18 +202,16 @@ func TestBackend_acceptance(t *testing.T) {
} }
dataRealpassword := map[string]interface{}{ dataRealpassword := map[string]interface{}{
"password": os.Getenv("RADIUS_USERPASS"), "password": password,
} }
dataWrongpassword := map[string]interface{}{ dataWrongpassword := map[string]interface{}{
"password": "wrongpassword", "password": "wrongpassword",
} }
username := os.Getenv("RADIUS_USERNAME")
logicaltest.Test(t, logicaltest.TestCase{ logicaltest.Test(t, logicaltest.TestCase{
Backend: b, Backend: b,
PreCheck: func() { testAccPreCheck(t) }, PreCheck: testAccPreCheck(t, host, port),
AcceptanceTest: true, AcceptanceTest: true,
Steps: []logicaltest.TestStep{ Steps: []logicaltest.TestStep{
// Login with valid but unknown user will fail because unregistered_user_policies is emtpy // Login with valid but unknown user will fail because unregistered_user_policies is emtpy
@ -172,21 +241,15 @@ func TestBackend_acceptance(t *testing.T) {
}) })
} }
func testAccPreCheck(t *testing.T) { func testAccPreCheck(t *testing.T, host string, port int) func() {
if v := os.Getenv("RADIUS_HOST"); v == "" { return func() {
t.Fatal("RADIUS_HOST must be set for acceptance tests") if host == "" {
} t.Fatal("Host must be set for acceptance tests")
}
if v := os.Getenv("RADIUS_USERNAME"); v == "" { if port == 0 {
t.Fatal("RADIUS_USERNAME must be set for acceptance tests") t.Fatal("Port must be non-zero for acceptance tests")
} }
if v := os.Getenv("RADIUS_USERPASS"); v == "" {
t.Fatal("RADIUS_USERPASS must be set for acceptance tests")
}
if v := os.Getenv("RADIUS_SECRET"); v == "" {
t.Fatal("RADIUS_SECRET must be set for acceptance tests")
} }
} }
@ -249,7 +312,7 @@ func testAccUserLoginPolicy(t *testing.T, user string, data map[string]interface
Operation: logical.UpdateOperation, Operation: logical.UpdateOperation,
Path: "login/" + user, Path: "login/" + user,
Data: data, Data: data,
ErrorOk: false, ErrorOk: expectError,
Unauthenticated: true, Unauthenticated: true,
//Check: logicaltest.TestCheckAuth(policies), //Check: logicaltest.TestCheckAuth(policies),
Check: func(resp *logical.Response) error { Check: func(resp *logical.Response) error {

View File

@ -50,7 +50,7 @@ func (scp *SQLCredentialsProducer) GenerateUsername(config dbplugin.UsernameConf
} }
username = fmt.Sprintf("%s%s%s", username, scp.Separator, userUUID) username = fmt.Sprintf("%s%s%s", username, scp.Separator, userUUID)
username = fmt.Sprintf("%s%s%s", username, scp.Separator, fmt.Sprint(time.Now().UTC().Unix())) username = fmt.Sprintf("%s%s%s", username, scp.Separator, fmt.Sprint(time.Now().Unix()))
if scp.UsernameLen > 0 && len(username) > scp.UsernameLen { if scp.UsernameLen > 0 && len(username) > scp.UsernameLen {
username = username[:scp.UsernameLen] username = username[:scp.UsernameLen]
} }

View File

@ -134,7 +134,7 @@ func (c *Core) enableCredential(ctx context.Context, entry *MountEntry) error {
// Update the auth table // Update the auth table
newTable := c.auth.shallowClone() newTable := c.auth.shallowClone()
newTable.Entries = append(newTable.Entries, entry) newTable.Entries = append(newTable.Entries, entry)
if err := c.persistAuth(ctx, newTable, entry.Local); err != nil { if err := c.persistAuth(ctx, newTable, &entry.Local); err != nil {
return errors.New("failed to update auth table") return errors.New("failed to update auth table")
} }
@ -235,7 +235,7 @@ func (c *Core) removeCredEntry(ctx context.Context, path string) error {
} }
// Update the auth table // Update the auth table
if err := c.persistAuth(ctx, newTable, entry.Local); err != nil { if err := c.persistAuth(ctx, newTable, &entry.Local); err != nil {
return errors.New("failed to update auth table") return errors.New("failed to update auth table")
} }
@ -281,7 +281,7 @@ func (c *Core) taintCredEntry(ctx context.Context, path string) error {
} }
// Update the auth table // Update the auth table
if err := c.persistAuth(ctx, c.auth, entry.Local); err != nil { if err := c.persistAuth(ctx, c.auth, &entry.Local); err != nil {
return errors.New("failed to update auth table") return errors.New("failed to update auth table")
} }
@ -369,7 +369,7 @@ func (c *Core) loadCredentials(ctx context.Context) error {
return nil return nil
} }
if err := c.persistAuth(ctx, c.auth, false); err != nil { if err := c.persistAuth(ctx, c.auth, nil); err != nil {
c.logger.Error("failed to persist auth table", "error", err) c.logger.Error("failed to persist auth table", "error", err)
return errLoadAuthFailed return errLoadAuthFailed
} }
@ -377,7 +377,7 @@ func (c *Core) loadCredentials(ctx context.Context) error {
} }
// persistAuth is used to persist the auth table after modification // persistAuth is used to persist the auth table after modification
func (c *Core) persistAuth(ctx context.Context, table *MountTable, localOnly bool) error { func (c *Core) persistAuth(ctx context.Context, table *MountTable, local *bool) error {
if table.Type != credentialTableType { if table.Type != credentialTableType {
c.logger.Error("given table to persist has wrong type", "actual_type", table.Type, "expected_type", credentialTableType) c.logger.Error("given table to persist has wrong type", "actual_type", table.Type, "expected_type", credentialTableType)
return fmt.Errorf("invalid table type given, not persisting") return fmt.Errorf("invalid table type given, not persisting")
@ -406,45 +406,49 @@ func (c *Core) persistAuth(ctx context.Context, table *MountTable, localOnly boo
} }
} }
if !localOnly { writeTable := func(mt *MountTable, path string) error {
// Marshal the table // Encode the mount table into JSON and compress it (lzw).
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAuth, nil) compressedBytes, err := jsonutil.EncodeJSONAndCompress(mt, nil)
if err != nil { if err != nil {
c.logger.Error("failed to encode and/or compress auth table", "error", err) c.logger.Error("failed to encode or compress auth mount table", "error", err)
return err return err
} }
// Create an entry // Create an entry
entry := &Entry{ entry := &Entry{
Key: coreAuthConfigPath, Key: path,
Value: compressedBytes, Value: compressedBytes,
} }
// Write to the physical backend // Write to the physical backend
if err := c.barrier.Put(ctx, entry); err != nil { if err := c.barrier.Put(ctx, entry); err != nil {
c.logger.Error("failed to persist auth table", "error", err) c.logger.Error("failed to persist auth mount table", "error", err)
return err return err
} }
return nil
} }
// Repeat with local auth var err error
compressedBytes, err := jsonutil.EncodeJSONAndCompress(localAuth, nil) switch {
if err != nil { case local == nil:
c.logger.Error("failed to encode and/or compress local auth table", "error", err) // Write non-local mounts
return err err := writeTable(nonLocalAuth, coreAuthConfigPath)
if err != nil {
return err
}
// Write local mounts
err = writeTable(localAuth, coreLocalAuthConfigPath)
if err != nil {
return err
}
case *local:
err = writeTable(localAuth, coreLocalAuthConfigPath)
default:
err = writeTable(nonLocalAuth, coreAuthConfigPath)
} }
entry := &Entry{ return err
Key: coreLocalAuthConfigPath,
Value: compressedBytes,
}
if err := c.barrier.Put(ctx, entry); err != nil {
c.logger.Error("failed to persist local auth table", "error", err)
return err
}
return nil
} }
// setupCredentials is invoked after we've loaded the auth table to // setupCredentials is invoked after we've loaded the auth table to
@ -520,7 +524,7 @@ func (c *Core) setupCredentials(ctx context.Context) error {
} }
if persistNeeded { if persistNeeded {
return c.persistAuth(ctx, c.auth, false) return c.persistAuth(ctx, c.auth, nil)
} }
return nil return nil

View File

@ -164,7 +164,7 @@ func TestCore_EnableCredential_Local(t *testing.T) {
} }
c.auth.Entries[1].Local = true c.auth.Entries[1].Local = true
if err := c.persistAuth(context.Background(), c.auth, false); err != nil { if err := c.persistAuth(context.Background(), c.auth, nil); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -1988,9 +1988,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
var err error var err error
switch { switch {
case strings.HasPrefix(path, "auth/"): case strings.HasPrefix(path, "auth/"):
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
default: default:
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
} }
if err != nil { if err != nil {
mountEntry.Description = oldDesc mountEntry.Description = oldDesc
@ -2011,9 +2011,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
var err error var err error
switch { switch {
case strings.HasPrefix(path, "auth/"): case strings.HasPrefix(path, "auth/"):
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
default: default:
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
} }
if err != nil { if err != nil {
mountEntry.Config.AuditNonHMACRequestKeys = oldVal mountEntry.Config.AuditNonHMACRequestKeys = oldVal
@ -2037,9 +2037,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
var err error var err error
switch { switch {
case strings.HasPrefix(path, "auth/"): case strings.HasPrefix(path, "auth/"):
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
default: default:
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
} }
if err != nil { if err != nil {
mountEntry.Config.AuditNonHMACResponseKeys = oldVal mountEntry.Config.AuditNonHMACResponseKeys = oldVal
@ -2068,9 +2068,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
var err error var err error
switch { switch {
case strings.HasPrefix(path, "auth/"): case strings.HasPrefix(path, "auth/"):
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
default: default:
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
} }
if err != nil { if err != nil {
mountEntry.Config.ListingVisibility = oldVal mountEntry.Config.ListingVisibility = oldVal
@ -2092,9 +2092,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
var err error var err error
switch { switch {
case strings.HasPrefix(path, "auth/"): case strings.HasPrefix(path, "auth/"):
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
default: default:
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
} }
if err != nil { if err != nil {
mountEntry.Config.PassthroughRequestHeaders = oldVal mountEntry.Config.PassthroughRequestHeaders = oldVal
@ -2154,9 +2154,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string,
// Update the mount table // Update the mount table
switch { switch {
case strings.HasPrefix(path, "auth/"): case strings.HasPrefix(path, "auth/"):
err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local)
default: default:
err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local)
} }
if err != nil { if err != nil {
mountEntry.Options = oldVal mountEntry.Options = oldVal

View File

@ -37,9 +37,9 @@ func (b *SystemBackend) tuneMountTTLs(ctx context.Context, path string, me *Moun
var err error var err error
switch { switch {
case strings.HasPrefix(path, credentialRoutePrefix): case strings.HasPrefix(path, credentialRoutePrefix):
err = b.Core.persistAuth(ctx, b.Core.auth, me.Local) err = b.Core.persistAuth(ctx, b.Core.auth, &me.Local)
default: default:
err = b.Core.persistMounts(ctx, b.Core.mounts, me.Local) err = b.Core.persistMounts(ctx, b.Core.mounts, &me.Local)
} }
if err != nil { if err != nil {
me.Config.MaxLeaseTTL = origMax me.Config.MaxLeaseTTL = origMax

View File

@ -336,7 +336,7 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry) error {
newTable := c.mounts.shallowClone() newTable := c.mounts.shallowClone()
newTable.Entries = append(newTable.Entries, entry) newTable.Entries = append(newTable.Entries, entry)
if err := c.persistMounts(ctx, newTable, entry.Local); err != nil { if err := c.persistMounts(ctx, newTable, &entry.Local); err != nil {
c.logger.Error("failed to update mount table", "error", err) c.logger.Error("failed to update mount table", "error", err)
return logical.CodedError(500, "failed to update mount table") return logical.CodedError(500, "failed to update mount table")
} }
@ -457,7 +457,7 @@ func (c *Core) removeMountEntry(ctx context.Context, path string) error {
} }
// Update the mount table // Update the mount table
if err := c.persistMounts(ctx, newTable, entry.Local); err != nil { if err := c.persistMounts(ctx, newTable, &entry.Local); err != nil {
c.logger.Error("failed to remove entry from mounts table", "error", err) c.logger.Error("failed to remove entry from mounts table", "error", err)
return logical.CodedError(500, "failed to remove entry from mounts table") return logical.CodedError(500, "failed to remove entry from mounts table")
} }
@ -480,7 +480,7 @@ func (c *Core) taintMountEntry(ctx context.Context, path string) error {
} }
// Update the mount table // Update the mount table
if err := c.persistMounts(ctx, c.mounts, entry.Local); err != nil { if err := c.persistMounts(ctx, c.mounts, &entry.Local); err != nil {
c.logger.Error("failed to taint entry in mounts table", "error", err) c.logger.Error("failed to taint entry in mounts table", "error", err)
return logical.CodedError(500, "failed to taint entry in mounts table") return logical.CodedError(500, "failed to taint entry in mounts table")
} }
@ -571,7 +571,7 @@ func (c *Core) remount(ctx context.Context, src, dst string) error {
} }
// Update the mount table // Update the mount table
if err := c.persistMounts(ctx, c.mounts, entry.Local); err != nil { if err := c.persistMounts(ctx, c.mounts, &entry.Local); err != nil {
entry.Path = src entry.Path = src
entry.Tainted = true entry.Tainted = true
c.mountsLock.Unlock() c.mountsLock.Unlock()
@ -710,7 +710,7 @@ func (c *Core) loadMounts(ctx context.Context) error {
return nil return nil
} }
if err := c.persistMounts(ctx, c.mounts, false); err != nil { if err := c.persistMounts(ctx, c.mounts, nil); err != nil {
c.logger.Error("failed to persist mount table", "error", err) c.logger.Error("failed to persist mount table", "error", err)
return errLoadMountsFailed return errLoadMountsFailed
} }
@ -718,7 +718,7 @@ func (c *Core) loadMounts(ctx context.Context) error {
} }
// persistMounts is used to persist the mount table after modification // persistMounts is used to persist the mount table after modification
func (c *Core) persistMounts(ctx context.Context, table *MountTable, localOnly bool) error { func (c *Core) persistMounts(ctx context.Context, table *MountTable, local *bool) error {
if table.Type != mountTableType { if table.Type != mountTableType {
c.logger.Error("given table to persist has wrong type", "actual_type", table.Type, "expected_type", mountTableType) c.logger.Error("given table to persist has wrong type", "actual_type", table.Type, "expected_type", mountTableType)
return fmt.Errorf("invalid table type given, not persisting") return fmt.Errorf("invalid table type given, not persisting")
@ -747,17 +747,17 @@ func (c *Core) persistMounts(ctx context.Context, table *MountTable, localOnly b
} }
} }
if !localOnly { writeTable := func(mt *MountTable, path string) error {
// Encode the mount table into JSON and compress it (lzw). // Encode the mount table into JSON and compress it (lzw).
compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalMounts, nil) compressedBytes, err := jsonutil.EncodeJSONAndCompress(mt, nil)
if err != nil { if err != nil {
c.logger.Error("failed to encode and/or compress the mount table", "error", err) c.logger.Error("failed to encode or compress mount table", "error", err)
return err return err
} }
// Create an entry // Create an entry
entry := &Entry{ entry := &Entry{
Key: coreMountConfigPath, Key: path,
Value: compressedBytes, Value: compressedBytes,
} }
@ -766,26 +766,33 @@ func (c *Core) persistMounts(ctx context.Context, table *MountTable, localOnly b
c.logger.Error("failed to persist mount table", "error", err) c.logger.Error("failed to persist mount table", "error", err)
return err return err
} }
return nil
} }
// Repeat with local mounts var err error
compressedBytes, err := jsonutil.EncodeJSONAndCompress(localMounts, nil) switch {
if err != nil { case local == nil:
c.logger.Error("failed to encode and/or compress the local mount table", "error", err) // Write non-local mounts
return err err := writeTable(nonLocalMounts, coreMountConfigPath)
if err != nil {
return err
}
// Write local mounts
err = writeTable(localMounts, coreLocalMountConfigPath)
if err != nil {
return err
}
case *local:
// Write local mounts
err = writeTable(localMounts, coreLocalMountConfigPath)
default:
// Write non-local mounts
err = writeTable(nonLocalMounts, coreMountConfigPath)
} }
entry := &Entry{ return err
Key: coreLocalMountConfigPath,
Value: compressedBytes,
}
if err := c.barrier.Put(ctx, entry); err != nil {
c.logger.Error("failed to persist local mount table", "error", err)
return err
}
return nil
} }
// setupMounts is invoked after we've loaded the mount table to // setupMounts is invoked after we've loaded the mount table to

View File

@ -161,7 +161,7 @@ func TestCore_Mount_Local(t *testing.T) {
} }
c.mounts.Entries[1].Local = true c.mounts.Entries[1].Local = true
if err := c.persistMounts(context.Background(), c.mounts, false); err != nil { if err := c.persistMounts(context.Background(), c.mounts, nil); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -557,7 +557,7 @@ func testCore_MountTable_UpgradeToTyped_Common(
t.Fatal(err) t.Fatal(err)
} }
var persistFunc func(context.Context, *MountTable, bool) error var persistFunc func(context.Context, *MountTable, *bool) error
// It should load successfully and be upgraded and persisted // It should load successfully and be upgraded and persisted
switch testType { switch testType {
@ -571,7 +571,13 @@ func testCore_MountTable_UpgradeToTyped_Common(
mt = c.auth mt = c.auth
case "audits": case "audits":
err = c.loadAudits(context.Background()) err = c.loadAudits(context.Background())
persistFunc = c.persistAudit persistFunc = func(ctx context.Context, mt *MountTable, b *bool) error {
if b == nil {
b = new(bool)
*b = false
}
return c.persistAudit(ctx, mt, *b)
}
mt = c.audit mt = c.audit
} }
if err != nil { if err != nil {
@ -600,19 +606,19 @@ func testCore_MountTable_UpgradeToTyped_Common(
// Now try saving invalid versions // Now try saving invalid versions
origTableType := mt.Type origTableType := mt.Type
mt.Type = "foo" mt.Type = "foo"
if err := persistFunc(context.Background(), mt, false); err == nil { if err := persistFunc(context.Background(), mt, nil); err == nil {
t.Fatal("expected error") t.Fatal("expected error")
} }
if len(mt.Entries) > 0 { if len(mt.Entries) > 0 {
mt.Type = origTableType mt.Type = origTableType
mt.Entries[0].Table = "bar" mt.Entries[0].Table = "bar"
if err := persistFunc(context.Background(), mt, false); err == nil { if err := persistFunc(context.Background(), mt, nil); err == nil {
t.Fatal("expected error") t.Fatal("expected error")
} }
mt.Entries[0].Table = mt.Type mt.Entries[0].Table = mt.Type
if err := persistFunc(context.Background(), mt, false); err != nil { if err := persistFunc(context.Background(), mt, nil); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }

View File

@ -121,9 +121,9 @@ to specify where the configuration is.
allows the decryption/encryption of raw data into and out of the security allows the decryption/encryption of raw data into and out of the security
barrier. This is a highly privileged endpoint. barrier. This is a highly privileged endpoint.
- `ui` `(bool: false, Enterprise-only)` Enables the built-in web UI, which is - `ui` `(bool: false)` Enables the built-in web UI, which is
available on all listeners (address + port) at the `/ui` path. Browsers accessing available on all listeners (address + port) at the `/ui` path. (Vault Enterprise, or Vault OSS 0.10+)
the standard Vault API address will automatically redirect there. This can also Browsers accessing the standard Vault API address will automatically redirect there. This can also
be provided via the environment variable `VAULT_UI`. be provided via the environment variable `VAULT_UI`.
- `pid_file` `(string: "")` - Path to the file in which the Vault server's - `pid_file` `(string: "")` - Path to the file in which the Vault server's

View File

@ -111,7 +111,7 @@ server_.
During initialization, the encryption keys are generated, unseal keys are During initialization, the encryption keys are generated, unseal keys are
created, and the initial root token is setup. To initialize Vault use `vault created, and the initial root token is setup. To initialize Vault use `vault
init`. This is an _unauthenticated_ request, but it only works on brand new operator init`. This is an _unauthenticated_ request, but it only works on brand new
Vaults with no data: Vaults with no data:
```text ```text