From 5e35356f9fba51c72da0e2dc653796a3edac9453 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Wed, 11 Apr 2018 10:33:40 -0400 Subject: [PATCH 1/5] Remove UTC call from SQL creds helper (#4336) Unix() by definition is always number of seconds since Unix epoch UTC. --- plugins/helper/database/credsutil/sql.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/helper/database/credsutil/sql.go b/plugins/helper/database/credsutil/sql.go index af9a746b63..2f9cc7d19e 100644 --- a/plugins/helper/database/credsutil/sql.go +++ b/plugins/helper/database/credsutil/sql.go @@ -50,7 +50,7 @@ func (scp *SQLCredentialsProducer) GenerateUsername(config dbplugin.UsernameConf } username = fmt.Sprintf("%s%s%s", username, scp.Separator, userUUID) - username = fmt.Sprintf("%s%s%s", username, scp.Separator, fmt.Sprint(time.Now().UTC().Unix())) + username = fmt.Sprintf("%s%s%s", username, scp.Separator, fmt.Sprint(time.Now().Unix())) if scp.UsernameLen > 0 && len(username) > scp.UsernameLen { username = username[:scp.UsernameLen] } From 898f710d901329f0974fcaf84cfa1ce54f1dcaf3 Mon Sep 17 00:00:00 2001 From: Calvin Leung Huang Date: Wed, 11 Apr 2018 14:26:35 -0400 Subject: [PATCH 2/5] Dockerize radius auth backend acceptance tests (#4276) --- builtin/credential/radius/backend_test.go | 145 ++++++++++++++++------ 1 file changed, 104 insertions(+), 41 deletions(-) diff --git a/builtin/credential/radius/backend_test.go b/builtin/credential/radius/backend_test.go index 300284b9f3..76effcb857 100644 --- a/builtin/credential/radius/backend_test.go +++ b/builtin/credential/radius/backend_test.go @@ -5,18 +5,71 @@ import ( "fmt" "os" "reflect" + "strconv" "testing" "time" "github.com/hashicorp/vault/logical" logicaltest "github.com/hashicorp/vault/logical/testing" + dockertest "gopkg.in/ory-am/dockertest.v3" ) const ( testSysTTL = time.Hour * 10 testSysMaxTTL = time.Hour * 20 + + envRadiusRadiusHost = "RADIUS_HOST" + envRadiusPort = "RADIUS_PORT" + envRadiusSecret = "RADIUS_SECRET" + envRadiusUsername = "RADIUS_USERNAME" + envRadiusUserPass = "RADIUS_USERPASS" ) +func prepareRadiusTestContainer(t *testing.T) (func(), string, int) { + if os.Getenv(envRadiusRadiusHost) != "" { + port, _ := strconv.Atoi(os.Getenv(envRadiusPort)) + return func() {}, os.Getenv(envRadiusRadiusHost), port + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + runOpts := &dockertest.RunOptions{ + Repository: "jumanjiman/radiusd", + Cmd: []string{"-f", "-l", "stdout"}, + ExposedPorts: []string{"1812/udp"}, + Tag: "latest", + } + resource, err := pool.RunWithOptions(runOpts) + if err != nil { + t.Fatalf("Could not start local radius docker container: %s", err) + } + + cleanup := func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + port, _ := strconv.Atoi(resource.GetPort("1812/udp")) + address := fmt.Sprintf("127.0.0.1") + + // exponential backoff-retry + if err = pool.Retry(func() error { + // There's no straightfoward way to check the state, but the server starts + // up quick so a 2 second sleep should be enough. + time.Sleep(2 * time.Second) + return nil + }); err != nil { + cleanup() + t.Fatalf("Could not connect to radius docker container: %s", err) + } + return cleanup, address, port +} + func TestBackend_Config(t *testing.T) { b, err := Factory(context.Background(), &logical.BackendConfig{ Logger: nil, @@ -29,43 +82,43 @@ func TestBackend_Config(t *testing.T) { t.Fatalf("Unable to create backend: %s", err) } - config_data_basic := map[string]interface{}{ + configDataBasic := map[string]interface{}{ "host": "test.radius.hostname.com", "secret": "test-secret", } - config_data_missingrequired := map[string]interface{}{ + configDataMissingRequired := map[string]interface{}{ "host": "test.radius.hostname.com", } - config_data_invalidport := map[string]interface{}{ + configDataEmptyPort := map[string]interface{}{ + "host": "test.radius.hostname.com", + "port": "", + "secret": "test-secret", + } + + configDataInvalidPort := map[string]interface{}{ "host": "test.radius.hostname.com", "port": "notnumeric", "secret": "test-secret", } - config_data_invalidbool := map[string]interface{}{ + configDataInvalidBool := map[string]interface{}{ "host": "test.radius.hostname.com", "secret": "test-secret", "unregistered_user_policies": "test", } - config_data_emptyport := map[string]interface{}{ - "host": "test.radius.hostname.com", - "port": "", - "secret": "test-secret", - } - logicaltest.Test(t, logicaltest.TestCase{ AcceptanceTest: false, // PreCheck: func() { testAccPreCheck(t) }, Backend: b, Steps: []logicaltest.TestStep{ - testConfigWrite(t, config_data_basic, false), - testConfigWrite(t, config_data_emptyport, true), - testConfigWrite(t, config_data_invalidport, true), - testConfigWrite(t, config_data_invalidbool, true), - testConfigWrite(t, config_data_missingrequired, true), + testConfigWrite(t, configDataBasic, false), + testConfigWrite(t, configDataMissingRequired, true), + testConfigWrite(t, configDataEmptyPort, true), + testConfigWrite(t, configDataInvalidPort, true), + testConfigWrite(t, configDataInvalidBool, true), }, }) } @@ -93,7 +146,6 @@ func TestBackend_users(t *testing.T) { } func TestBackend_acceptance(t *testing.T) { - if os.Getenv(logicaltest.TestEnvVar) == "" { t.Skip(fmt.Sprintf("Acceptance tests skipped unless env '%s' set", logicaltest.TestEnvVar)) return @@ -110,10 +162,29 @@ func TestBackend_acceptance(t *testing.T) { t.Fatalf("Unable to create backend: %s", err) } + cleanup, host, port := prepareRadiusTestContainer(t) + defer cleanup() + + // These defaults are specific to the jumanjiman/radiusd docker image + username := os.Getenv(envRadiusUsername) + if username == "" { + username = "test" + } + + password := os.Getenv(envRadiusUserPass) + if password == "" { + password = "test" + } + + secret := os.Getenv(envRadiusSecret) + if len(secret) == 0 { + secret = "testing123" + } + configDataAcceptanceAllowUnreg := map[string]interface{}{ - "host": os.Getenv("RADIUS_HOST"), - "port": os.Getenv("RADIUS_PORT"), - "secret": os.Getenv("RADIUS_SECRET"), + "host": host, + "port": strconv.Itoa(port), + "secret": secret, "unregistered_user_policies": "policy1,policy2", } if configDataAcceptanceAllowUnreg["port"] == "" { @@ -121,9 +192,9 @@ func TestBackend_acceptance(t *testing.T) { } configDataAcceptanceNoAllowUnreg := map[string]interface{}{ - "host": os.Getenv("RADIUS_HOST"), - "port": os.Getenv("RADIUS_PORT"), - "secret": os.Getenv("RADIUS_SECRET"), + "host": host, + "port": strconv.Itoa(port), + "secret": secret, "unregistered_user_policies": "", } if configDataAcceptanceNoAllowUnreg["port"] == "" { @@ -131,18 +202,16 @@ func TestBackend_acceptance(t *testing.T) { } dataRealpassword := map[string]interface{}{ - "password": os.Getenv("RADIUS_USERPASS"), + "password": password, } dataWrongpassword := map[string]interface{}{ "password": "wrongpassword", } - username := os.Getenv("RADIUS_USERNAME") - logicaltest.Test(t, logicaltest.TestCase{ Backend: b, - PreCheck: func() { testAccPreCheck(t) }, + PreCheck: testAccPreCheck(t, host, port), AcceptanceTest: true, Steps: []logicaltest.TestStep{ // Login with valid but unknown user will fail because unregistered_user_policies is emtpy @@ -172,21 +241,15 @@ func TestBackend_acceptance(t *testing.T) { }) } -func testAccPreCheck(t *testing.T) { - if v := os.Getenv("RADIUS_HOST"); v == "" { - t.Fatal("RADIUS_HOST must be set for acceptance tests") - } +func testAccPreCheck(t *testing.T, host string, port int) func() { + return func() { + if host == "" { + t.Fatal("Host must be set for acceptance tests") + } - if v := os.Getenv("RADIUS_USERNAME"); v == "" { - t.Fatal("RADIUS_USERNAME must be set for acceptance tests") - } - - if v := os.Getenv("RADIUS_USERPASS"); v == "" { - t.Fatal("RADIUS_USERPASS must be set for acceptance tests") - } - - if v := os.Getenv("RADIUS_SECRET"); v == "" { - t.Fatal("RADIUS_SECRET must be set for acceptance tests") + if port == 0 { + t.Fatal("Port must be non-zero for acceptance tests") + } } } @@ -249,7 +312,7 @@ func testAccUserLoginPolicy(t *testing.T, user string, data map[string]interface Operation: logical.UpdateOperation, Path: "login/" + user, Data: data, - ErrorOk: false, + ErrorOk: expectError, Unauthenticated: true, //Check: logicaltest.TestCheckAuth(policies), Check: func(resp *logical.Response) error { From 986ace5183d30d2d81a752d969ca78f314df74d1 Mon Sep 17 00:00:00 2001 From: James Mannion Date: Wed, 11 Apr 2018 19:26:53 +0100 Subject: [PATCH 3/5] Fixes a reference to deprecated init command (#4338) Replace "vault init" with "vault operator init" in initialising the vault section. --- website/source/intro/getting-started/deploy.html.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/source/intro/getting-started/deploy.html.md b/website/source/intro/getting-started/deploy.html.md index 3e8ea96f6e..9fa9d37b5d 100644 --- a/website/source/intro/getting-started/deploy.html.md +++ b/website/source/intro/getting-started/deploy.html.md @@ -111,7 +111,7 @@ server_. During initialization, the encryption keys are generated, unseal keys are created, and the initial root token is setup. To initialize Vault use `vault -init`. This is an _unauthenticated_ request, but it only works on brand new +operator init`. This is an _unauthenticated_ request, but it only works on brand new Vaults with no data: ```text From 0ac5933c24f2154907ddb2916400ab84451a6e33 Mon Sep 17 00:00:00 2001 From: Peter Souter Date: Wed, 11 Apr 2018 19:27:58 +0100 Subject: [PATCH 4/5] Remove Enterprise Only flag (#4337) --- website/source/docs/configuration/index.html.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/website/source/docs/configuration/index.html.md b/website/source/docs/configuration/index.html.md index 4ea3c0499b..76aa8a016d 100644 --- a/website/source/docs/configuration/index.html.md +++ b/website/source/docs/configuration/index.html.md @@ -121,9 +121,9 @@ to specify where the configuration is. allows the decryption/encryption of raw data into and out of the security barrier. This is a highly privileged endpoint. -- `ui` `(bool: false, Enterprise-only)` – Enables the built-in web UI, which is - available on all listeners (address + port) at the `/ui` path. Browsers accessing - the standard Vault API address will automatically redirect there. This can also +- `ui` `(bool: false)` – Enables the built-in web UI, which is + available on all listeners (address + port) at the `/ui` path. (Vault Enterprise, or Vault OSS 0.10+) + Browsers accessing the standard Vault API address will automatically redirect there. This can also be provided via the environment variable `VAULT_UI`. - `pid_file` `(string: "")` - Path to the file in which the Vault server's From 92de42170c20d958bffe6fe2252e7fe4d6147d40 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 11 Apr 2018 11:32:55 -0700 Subject: [PATCH 5/5] Port some ent mount changes (#4330) --- vault/auth.go | 60 ++++++++++++++++++--------------- vault/auth_test.go | 2 +- vault/logical_system.go | 24 ++++++------- vault/logical_system_helpers.go | 4 +-- vault/mount.go | 59 ++++++++++++++++++-------------- vault/mount_test.go | 18 ++++++---- 6 files changed, 92 insertions(+), 75 deletions(-) diff --git a/vault/auth.go b/vault/auth.go index e94b85d1e1..d75475a050 100644 --- a/vault/auth.go +++ b/vault/auth.go @@ -134,7 +134,7 @@ func (c *Core) enableCredential(ctx context.Context, entry *MountEntry) error { // Update the auth table newTable := c.auth.shallowClone() newTable.Entries = append(newTable.Entries, entry) - if err := c.persistAuth(ctx, newTable, entry.Local); err != nil { + if err := c.persistAuth(ctx, newTable, &entry.Local); err != nil { return errors.New("failed to update auth table") } @@ -235,7 +235,7 @@ func (c *Core) removeCredEntry(ctx context.Context, path string) error { } // Update the auth table - if err := c.persistAuth(ctx, newTable, entry.Local); err != nil { + if err := c.persistAuth(ctx, newTable, &entry.Local); err != nil { return errors.New("failed to update auth table") } @@ -281,7 +281,7 @@ func (c *Core) taintCredEntry(ctx context.Context, path string) error { } // Update the auth table - if err := c.persistAuth(ctx, c.auth, entry.Local); err != nil { + if err := c.persistAuth(ctx, c.auth, &entry.Local); err != nil { return errors.New("failed to update auth table") } @@ -369,7 +369,7 @@ func (c *Core) loadCredentials(ctx context.Context) error { return nil } - if err := c.persistAuth(ctx, c.auth, false); err != nil { + if err := c.persistAuth(ctx, c.auth, nil); err != nil { c.logger.Error("failed to persist auth table", "error", err) return errLoadAuthFailed } @@ -377,7 +377,7 @@ func (c *Core) loadCredentials(ctx context.Context) error { } // persistAuth is used to persist the auth table after modification -func (c *Core) persistAuth(ctx context.Context, table *MountTable, localOnly bool) error { +func (c *Core) persistAuth(ctx context.Context, table *MountTable, local *bool) error { if table.Type != credentialTableType { c.logger.Error("given table to persist has wrong type", "actual_type", table.Type, "expected_type", credentialTableType) return fmt.Errorf("invalid table type given, not persisting") @@ -406,45 +406,49 @@ func (c *Core) persistAuth(ctx context.Context, table *MountTable, localOnly boo } } - if !localOnly { - // Marshal the table - compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalAuth, nil) + writeTable := func(mt *MountTable, path string) error { + // Encode the mount table into JSON and compress it (lzw). + compressedBytes, err := jsonutil.EncodeJSONAndCompress(mt, nil) if err != nil { - c.logger.Error("failed to encode and/or compress auth table", "error", err) + c.logger.Error("failed to encode or compress auth mount table", "error", err) return err } // Create an entry entry := &Entry{ - Key: coreAuthConfigPath, + Key: path, Value: compressedBytes, } // Write to the physical backend if err := c.barrier.Put(ctx, entry); err != nil { - c.logger.Error("failed to persist auth table", "error", err) + c.logger.Error("failed to persist auth mount table", "error", err) return err } + return nil } - // Repeat with local auth - compressedBytes, err := jsonutil.EncodeJSONAndCompress(localAuth, nil) - if err != nil { - c.logger.Error("failed to encode and/or compress local auth table", "error", err) - return err + var err error + switch { + case local == nil: + // Write non-local mounts + err := writeTable(nonLocalAuth, coreAuthConfigPath) + if err != nil { + return err + } + + // Write local mounts + err = writeTable(localAuth, coreLocalAuthConfigPath) + if err != nil { + return err + } + case *local: + err = writeTable(localAuth, coreLocalAuthConfigPath) + default: + err = writeTable(nonLocalAuth, coreAuthConfigPath) } - entry := &Entry{ - Key: coreLocalAuthConfigPath, - Value: compressedBytes, - } - - if err := c.barrier.Put(ctx, entry); err != nil { - c.logger.Error("failed to persist local auth table", "error", err) - return err - } - - return nil + return err } // setupCredentials is invoked after we've loaded the auth table to @@ -520,7 +524,7 @@ func (c *Core) setupCredentials(ctx context.Context) error { } if persistNeeded { - return c.persistAuth(ctx, c.auth, false) + return c.persistAuth(ctx, c.auth, nil) } return nil diff --git a/vault/auth_test.go b/vault/auth_test.go index 6f66b4c584..8b32275997 100644 --- a/vault/auth_test.go +++ b/vault/auth_test.go @@ -164,7 +164,7 @@ func TestCore_EnableCredential_Local(t *testing.T) { } c.auth.Entries[1].Local = true - if err := c.persistAuth(context.Background(), c.auth, false); err != nil { + if err := c.persistAuth(context.Background(), c.auth, nil); err != nil { t.Fatal(err) } diff --git a/vault/logical_system.go b/vault/logical_system.go index 298c69c070..0970075203 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -1988,9 +1988,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, var err error switch { case strings.HasPrefix(path, "auth/"): - err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) + err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local) default: - err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) + err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local) } if err != nil { mountEntry.Description = oldDesc @@ -2011,9 +2011,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, var err error switch { case strings.HasPrefix(path, "auth/"): - err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) + err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local) default: - err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) + err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local) } if err != nil { mountEntry.Config.AuditNonHMACRequestKeys = oldVal @@ -2037,9 +2037,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, var err error switch { case strings.HasPrefix(path, "auth/"): - err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) + err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local) default: - err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) + err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local) } if err != nil { mountEntry.Config.AuditNonHMACResponseKeys = oldVal @@ -2068,9 +2068,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, var err error switch { case strings.HasPrefix(path, "auth/"): - err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) + err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local) default: - err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) + err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local) } if err != nil { mountEntry.Config.ListingVisibility = oldVal @@ -2092,9 +2092,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, var err error switch { case strings.HasPrefix(path, "auth/"): - err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) + err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local) default: - err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) + err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local) } if err != nil { mountEntry.Config.PassthroughRequestHeaders = oldVal @@ -2154,9 +2154,9 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, // Update the mount table switch { case strings.HasPrefix(path, "auth/"): - err = b.Core.persistAuth(ctx, b.Core.auth, mountEntry.Local) + err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local) default: - err = b.Core.persistMounts(ctx, b.Core.mounts, mountEntry.Local) + err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local) } if err != nil { mountEntry.Options = oldVal diff --git a/vault/logical_system_helpers.go b/vault/logical_system_helpers.go index 48cbb173c7..d9fdb046b7 100644 --- a/vault/logical_system_helpers.go +++ b/vault/logical_system_helpers.go @@ -37,9 +37,9 @@ func (b *SystemBackend) tuneMountTTLs(ctx context.Context, path string, me *Moun var err error switch { case strings.HasPrefix(path, credentialRoutePrefix): - err = b.Core.persistAuth(ctx, b.Core.auth, me.Local) + err = b.Core.persistAuth(ctx, b.Core.auth, &me.Local) default: - err = b.Core.persistMounts(ctx, b.Core.mounts, me.Local) + err = b.Core.persistMounts(ctx, b.Core.mounts, &me.Local) } if err != nil { me.Config.MaxLeaseTTL = origMax diff --git a/vault/mount.go b/vault/mount.go index 5ef79bf701..7aaf5d6a99 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -336,7 +336,7 @@ func (c *Core) mountInternal(ctx context.Context, entry *MountEntry) error { newTable := c.mounts.shallowClone() newTable.Entries = append(newTable.Entries, entry) - if err := c.persistMounts(ctx, newTable, entry.Local); err != nil { + if err := c.persistMounts(ctx, newTable, &entry.Local); err != nil { c.logger.Error("failed to update mount table", "error", err) return logical.CodedError(500, "failed to update mount table") } @@ -457,7 +457,7 @@ func (c *Core) removeMountEntry(ctx context.Context, path string) error { } // Update the mount table - if err := c.persistMounts(ctx, newTable, entry.Local); err != nil { + if err := c.persistMounts(ctx, newTable, &entry.Local); err != nil { c.logger.Error("failed to remove entry from mounts table", "error", err) return logical.CodedError(500, "failed to remove entry from mounts table") } @@ -480,7 +480,7 @@ func (c *Core) taintMountEntry(ctx context.Context, path string) error { } // Update the mount table - if err := c.persistMounts(ctx, c.mounts, entry.Local); err != nil { + if err := c.persistMounts(ctx, c.mounts, &entry.Local); err != nil { c.logger.Error("failed to taint entry in mounts table", "error", err) return logical.CodedError(500, "failed to taint entry in mounts table") } @@ -571,7 +571,7 @@ func (c *Core) remount(ctx context.Context, src, dst string) error { } // Update the mount table - if err := c.persistMounts(ctx, c.mounts, entry.Local); err != nil { + if err := c.persistMounts(ctx, c.mounts, &entry.Local); err != nil { entry.Path = src entry.Tainted = true c.mountsLock.Unlock() @@ -710,7 +710,7 @@ func (c *Core) loadMounts(ctx context.Context) error { return nil } - if err := c.persistMounts(ctx, c.mounts, false); err != nil { + if err := c.persistMounts(ctx, c.mounts, nil); err != nil { c.logger.Error("failed to persist mount table", "error", err) return errLoadMountsFailed } @@ -718,7 +718,7 @@ func (c *Core) loadMounts(ctx context.Context) error { } // persistMounts is used to persist the mount table after modification -func (c *Core) persistMounts(ctx context.Context, table *MountTable, localOnly bool) error { +func (c *Core) persistMounts(ctx context.Context, table *MountTable, local *bool) error { if table.Type != mountTableType { c.logger.Error("given table to persist has wrong type", "actual_type", table.Type, "expected_type", mountTableType) return fmt.Errorf("invalid table type given, not persisting") @@ -747,17 +747,17 @@ func (c *Core) persistMounts(ctx context.Context, table *MountTable, localOnly b } } - if !localOnly { + writeTable := func(mt *MountTable, path string) error { // Encode the mount table into JSON and compress it (lzw). - compressedBytes, err := jsonutil.EncodeJSONAndCompress(nonLocalMounts, nil) + compressedBytes, err := jsonutil.EncodeJSONAndCompress(mt, nil) if err != nil { - c.logger.Error("failed to encode and/or compress the mount table", "error", err) + c.logger.Error("failed to encode or compress mount table", "error", err) return err } // Create an entry entry := &Entry{ - Key: coreMountConfigPath, + Key: path, Value: compressedBytes, } @@ -766,26 +766,33 @@ func (c *Core) persistMounts(ctx context.Context, table *MountTable, localOnly b c.logger.Error("failed to persist mount table", "error", err) return err } + + return nil } - // Repeat with local mounts - compressedBytes, err := jsonutil.EncodeJSONAndCompress(localMounts, nil) - if err != nil { - c.logger.Error("failed to encode and/or compress the local mount table", "error", err) - return err + var err error + switch { + case local == nil: + // Write non-local mounts + err := writeTable(nonLocalMounts, coreMountConfigPath) + if err != nil { + return err + } + + // Write local mounts + err = writeTable(localMounts, coreLocalMountConfigPath) + if err != nil { + return err + } + case *local: + // Write local mounts + err = writeTable(localMounts, coreLocalMountConfigPath) + default: + // Write non-local mounts + err = writeTable(nonLocalMounts, coreMountConfigPath) } - entry := &Entry{ - Key: coreLocalMountConfigPath, - Value: compressedBytes, - } - - if err := c.barrier.Put(ctx, entry); err != nil { - c.logger.Error("failed to persist local mount table", "error", err) - return err - } - - return nil + return err } // setupMounts is invoked after we've loaded the mount table to diff --git a/vault/mount_test.go b/vault/mount_test.go index 87a0f9ed29..e773003571 100644 --- a/vault/mount_test.go +++ b/vault/mount_test.go @@ -161,7 +161,7 @@ func TestCore_Mount_Local(t *testing.T) { } c.mounts.Entries[1].Local = true - if err := c.persistMounts(context.Background(), c.mounts, false); err != nil { + if err := c.persistMounts(context.Background(), c.mounts, nil); err != nil { t.Fatal(err) } @@ -557,7 +557,7 @@ func testCore_MountTable_UpgradeToTyped_Common( t.Fatal(err) } - var persistFunc func(context.Context, *MountTable, bool) error + var persistFunc func(context.Context, *MountTable, *bool) error // It should load successfully and be upgraded and persisted switch testType { @@ -571,7 +571,13 @@ func testCore_MountTable_UpgradeToTyped_Common( mt = c.auth case "audits": err = c.loadAudits(context.Background()) - persistFunc = c.persistAudit + persistFunc = func(ctx context.Context, mt *MountTable, b *bool) error { + if b == nil { + b = new(bool) + *b = false + } + return c.persistAudit(ctx, mt, *b) + } mt = c.audit } if err != nil { @@ -600,19 +606,19 @@ func testCore_MountTable_UpgradeToTyped_Common( // Now try saving invalid versions origTableType := mt.Type mt.Type = "foo" - if err := persistFunc(context.Background(), mt, false); err == nil { + if err := persistFunc(context.Background(), mt, nil); err == nil { t.Fatal("expected error") } if len(mt.Entries) > 0 { mt.Type = origTableType mt.Entries[0].Table = "bar" - if err := persistFunc(context.Background(), mt, false); err == nil { + if err := persistFunc(context.Background(), mt, nil); err == nil { t.Fatal("expected error") } mt.Entries[0].Table = mt.Type - if err := persistFunc(context.Background(), mt, false); err != nil { + if err := persistFunc(context.Background(), mt, nil); err != nil { t.Fatal(err) } }