mirror of
https://github.com/hashicorp/vault.git
synced 2026-05-05 12:26:34 +02:00
Make predict it's own struct
The previous architecture would create an API client many times, slowing down the CLI exponentially for each new command added.
This commit is contained in:
parent
fb81547a3a
commit
3a1479bc8c
118
command/base.go
118
command/base.go
@ -48,85 +48,75 @@ type BaseCommand struct {
|
||||
|
||||
tokenHelper TokenHelperFunc
|
||||
|
||||
client *api.Client
|
||||
clientErr error
|
||||
clientOnce sync.Once
|
||||
// For testing
|
||||
client *api.Client
|
||||
}
|
||||
|
||||
// Client returns the HTTP API client. The client is cached on the command to
|
||||
// save performance on future calls.
|
||||
func (c *BaseCommand) Client() (*api.Client, error) {
|
||||
c.clientOnce.Do(func() {
|
||||
// This should never happen in reality and is just for testing. Nothing
|
||||
// should be setting the underlying client.
|
||||
if c.client != nil {
|
||||
return
|
||||
// Read the test client if present
|
||||
if c.client != nil {
|
||||
return c.client, nil
|
||||
}
|
||||
|
||||
config := api.DefaultConfig()
|
||||
|
||||
if err := config.ReadEnvironment(); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read environment")
|
||||
}
|
||||
|
||||
if c.flagAddress != "" {
|
||||
config.Address = c.flagAddress
|
||||
}
|
||||
|
||||
// If we need custom TLS configuration, then set it
|
||||
if c.flagCACert != "" || c.flagCAPath != "" || c.flagClientCert != "" ||
|
||||
c.flagClientKey != "" || c.flagTLSServerName != "" || c.flagTLSSkipVerify {
|
||||
t := &api.TLSConfig{
|
||||
CACert: c.flagCACert,
|
||||
CAPath: c.flagCAPath,
|
||||
ClientCert: c.flagClientCert,
|
||||
ClientKey: c.flagClientKey,
|
||||
TLSServerName: c.flagTLSServerName,
|
||||
Insecure: c.flagTLSSkipVerify,
|
||||
}
|
||||
config.ConfigureTLS(t)
|
||||
}
|
||||
|
||||
config := api.DefaultConfig()
|
||||
// Build the client
|
||||
client, err := api.NewClient(config)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create client")
|
||||
}
|
||||
|
||||
if err := config.ReadEnvironment(); err != nil {
|
||||
c.clientErr = errors.Wrap(err, "failed to read environment")
|
||||
return
|
||||
}
|
||||
// Set the wrapping function
|
||||
client.SetWrappingLookupFunc(c.DefaultWrappingLookupFunc)
|
||||
|
||||
if c.flagAddress != "" {
|
||||
config.Address = c.flagAddress
|
||||
}
|
||||
// Get the token if it came in from the environment
|
||||
token := client.Token()
|
||||
|
||||
// If we need custom TLS configuration, then set it
|
||||
if c.flagCACert != "" || c.flagCAPath != "" || c.flagClientCert != "" ||
|
||||
c.flagClientKey != "" || c.flagTLSServerName != "" || c.flagTLSSkipVerify {
|
||||
t := &api.TLSConfig{
|
||||
CACert: c.flagCACert,
|
||||
CAPath: c.flagCAPath,
|
||||
ClientCert: c.flagClientCert,
|
||||
ClientKey: c.flagClientKey,
|
||||
TLSServerName: c.flagTLSServerName,
|
||||
Insecure: c.flagTLSSkipVerify,
|
||||
// If we don't have a token, check the token helper
|
||||
if token == "" {
|
||||
if c.tokenHelper != nil {
|
||||
// If we have a token, then set that
|
||||
tokenHelper, err := c.tokenHelper()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get token helper")
|
||||
}
|
||||
config.ConfigureTLS(t)
|
||||
}
|
||||
|
||||
// Build the client
|
||||
client, err := api.NewClient(config)
|
||||
if err != nil {
|
||||
c.clientErr = errors.Wrap(err, "failed to create client")
|
||||
return
|
||||
}
|
||||
|
||||
// Set the wrapping function
|
||||
client.SetWrappingLookupFunc(c.DefaultWrappingLookupFunc)
|
||||
|
||||
// Get the token if it came in from the environment
|
||||
token := client.Token()
|
||||
|
||||
// If we don't have a token, check the token helper
|
||||
if token == "" {
|
||||
if c.tokenHelper != nil {
|
||||
// If we have a token, then set that
|
||||
tokenHelper, err := c.tokenHelper()
|
||||
if err != nil {
|
||||
c.clientErr = errors.Wrap(err, "failed to get token helper")
|
||||
return
|
||||
}
|
||||
token, err = tokenHelper.Get()
|
||||
if err != nil {
|
||||
c.clientErr = errors.Wrap(err, "failed to retrieve from token helper")
|
||||
return
|
||||
}
|
||||
token, err = tokenHelper.Get()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to retrieve from token helper")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set the token
|
||||
if token != "" {
|
||||
client.SetToken(token)
|
||||
}
|
||||
// Set the token
|
||||
if token != "" {
|
||||
client.SetToken(token)
|
||||
}
|
||||
|
||||
c.client = client
|
||||
})
|
||||
|
||||
return c.client, c.clientErr
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// DefaultWrappingLookupFunc is the default wrapping function based on the
|
||||
|
||||
@ -3,11 +3,30 @@ package command
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/posener/complete"
|
||||
)
|
||||
|
||||
type Predict struct {
|
||||
client *api.Client
|
||||
clientOnce sync.Once
|
||||
}
|
||||
|
||||
func NewPredict() *Predict {
|
||||
return &Predict{}
|
||||
}
|
||||
|
||||
func (p *Predict) Client() *api.Client {
|
||||
p.clientOnce.Do(func() {
|
||||
if p.client == nil { // For tests
|
||||
p.client, _ = api.NewClient(nil)
|
||||
}
|
||||
})
|
||||
return p.client
|
||||
}
|
||||
|
||||
// defaultPredictVaultMounts is the default list of mounts to return to the
|
||||
// user. This is a best-guess, given we haven't communicated with the Vault
|
||||
// server. If the user has no token or if the token does not have the default
|
||||
@ -15,48 +34,63 @@ import (
|
||||
// that returning nothing.
|
||||
var defaultPredictVaultMounts = []string{"cubbyhole/"}
|
||||
|
||||
// predictClient is the API client to use for prediction. We create this at the
|
||||
// beginning once, because completions are generated for each command (and this
|
||||
// doesn't change), and the only way to configure the predict/autocomplete
|
||||
// client is via environment variables. Even if the user specifies a flag, we
|
||||
// can't parse that flag until after the command is submitted.
|
||||
var predictClient *api.Client
|
||||
var predictClientOnce sync.Once
|
||||
|
||||
// PredictClient returns the cached API client for the predictor.
|
||||
func PredictClient() *api.Client {
|
||||
predictClientOnce.Do(func() {
|
||||
if predictClient == nil { // For tests
|
||||
predictClient, _ = api.NewClient(nil)
|
||||
}
|
||||
})
|
||||
return predictClient
|
||||
}
|
||||
|
||||
// PredictVaultFiles returns a predictor for Vault mounts and paths based on the
|
||||
// configured client for the base command. Unfortunately this happens pre-flag
|
||||
// parsing, so users must rely on environment variables for autocomplete if they
|
||||
// are not using Vault at the default endpoints.
|
||||
func (b *BaseCommand) PredictVaultFiles() complete.Predictor {
|
||||
client, err := b.Client()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return PredictVaultFiles(client)
|
||||
return NewPredict().VaultFiles()
|
||||
}
|
||||
|
||||
// PredictVaultFolders returns a predictor for "folders". See PredictVaultFiles
|
||||
// for more information and restrictions.
|
||||
func (b *BaseCommand) PredictVaultFolders() complete.Predictor {
|
||||
client, err := b.Client()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return PredictVaultFolders(client)
|
||||
return NewPredict().VaultFolders()
|
||||
}
|
||||
|
||||
// PredictVaultFiles returns a predictor for Vault "files". This is a public API
|
||||
// for consumers, but you probably want BaseCommand.PredictVaultFiles instead.
|
||||
func PredictVaultFiles(client *api.Client) complete.Predictor {
|
||||
return predictVaultPaths(client, true)
|
||||
// VaultFiles returns a predictor for Vault "files". This is a public API for
|
||||
// consumers, but you probably want BaseCommand.PredictVaultFiles instead.
|
||||
func (p *Predict) VaultFiles() complete.Predictor {
|
||||
return p.vaultPaths(true)
|
||||
}
|
||||
|
||||
// PredictVaultFolders returns a predictor for Vault "folders". This is a public
|
||||
// VaultFolders returns a predictor for Vault "folders". This is a public
|
||||
// API for consumers, but you probably want BaseCommand.PredictVaultFolders
|
||||
// instead.
|
||||
func PredictVaultFolders(client *api.Client) complete.Predictor {
|
||||
return predictVaultPaths(client, false)
|
||||
func (p *Predict) VaultFolders() complete.Predictor {
|
||||
return p.vaultPaths(false)
|
||||
}
|
||||
|
||||
// predictVaultPaths parses the CLI options and returns the "best" list of
|
||||
// possible paths. If there are any errors, this function returns an empty
|
||||
// result. All errors are suppressed since this is a prediction function.
|
||||
func predictVaultPaths(client *api.Client, includeFiles bool) complete.PredictFunc {
|
||||
// vaultPaths parses the CLI options and returns the "best" list of possible
|
||||
// paths. If there are any errors, this function returns an empty result. All
|
||||
// errors are suppressed since this is a prediction function.
|
||||
func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc {
|
||||
return func(args complete.Args) []string {
|
||||
// Do not predict more than one paths
|
||||
if predictHasPathArg(args.All) {
|
||||
if p.hasPathArg(args.All) {
|
||||
return nil
|
||||
}
|
||||
|
||||
client := p.Client()
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -64,9 +98,9 @@ func predictVaultPaths(client *api.Client, includeFiles bool) complete.PredictFu
|
||||
|
||||
var predictions []string
|
||||
if strings.Contains(path, "/") {
|
||||
predictions = predictPaths(client, path, includeFiles)
|
||||
predictions = p.paths(path, includeFiles)
|
||||
} else {
|
||||
predictions = predictMounts(client, path)
|
||||
predictions = p.mounts(path)
|
||||
}
|
||||
|
||||
// Either no results or many results, so return.
|
||||
@ -87,14 +121,19 @@ func predictVaultPaths(client *api.Client, includeFiles bool) complete.PredictFu
|
||||
|
||||
// Re-predict with the remaining path
|
||||
args.Last = predictions[0]
|
||||
return predictVaultPaths(client, includeFiles).Predict(args)
|
||||
return p.vaultPaths(includeFiles).Predict(args)
|
||||
}
|
||||
}
|
||||
|
||||
// predictMounts predicts all mounts which start with the given prefix. These
|
||||
// are predicted on mount path, not "type".
|
||||
func predictMounts(client *api.Client, path string) []string {
|
||||
mounts := predictListMounts(client)
|
||||
// mounts predicts all mounts which start with the given prefix. These are
|
||||
// predicted on mount path, not "type".
|
||||
func (p *Predict) mounts(path string) []string {
|
||||
client := p.Client()
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mounts := p.listMounts()
|
||||
|
||||
var predictions []string
|
||||
for _, m := range mounts {
|
||||
@ -106,8 +145,13 @@ func predictMounts(client *api.Client, path string) []string {
|
||||
return predictions
|
||||
}
|
||||
|
||||
// predictPaths predicts all paths which start with the given path.
|
||||
func predictPaths(client *api.Client, path string, includeFiles bool) []string {
|
||||
// paths predicts all paths which start with the given path.
|
||||
func (p *Predict) paths(path string, includeFiles bool) []string {
|
||||
client := p.Client()
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Vault does not support listing based on a sub-key, so we have to back-pedal
|
||||
// to the last "/" and return all paths on that "folder". Then we perform
|
||||
// client-side filtering.
|
||||
@ -117,7 +161,7 @@ func predictPaths(client *api.Client, path string, includeFiles bool) []string {
|
||||
root = root[:idx+1]
|
||||
}
|
||||
|
||||
paths := predictListPaths(client, root)
|
||||
paths := p.listPaths(root)
|
||||
|
||||
var predictions []string
|
||||
for _, p := range paths {
|
||||
@ -140,11 +184,16 @@ func predictPaths(client *api.Client, path string, includeFiles bool) []string {
|
||||
return predictions
|
||||
}
|
||||
|
||||
// predictListMounts returns a sorted list of the mount paths for Vault server
|
||||
// for which the client is configured to communicate with. This function returns
|
||||
// the default list of mounts if an error occurs.
|
||||
func predictListMounts(c *api.Client) []string {
|
||||
mounts, err := c.Sys().ListMounts()
|
||||
// listMounts returns a sorted list of the mount paths for Vault server for
|
||||
// which the client is configured to communicate with. This function returns the
|
||||
// default list of mounts if an error occurs.
|
||||
func (p *Predict) listMounts() []string {
|
||||
client := p.Client()
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mounts, err := client.Sys().ListMounts()
|
||||
if err != nil {
|
||||
return defaultPredictVaultMounts
|
||||
}
|
||||
@ -157,10 +206,15 @@ func predictListMounts(c *api.Client) []string {
|
||||
return list
|
||||
}
|
||||
|
||||
// predictListPaths returns a list of paths (HTTP LIST) for the given path. This
|
||||
// listPaths returns a list of paths (HTTP LIST) for the given path. This
|
||||
// function returns an empty list of any errors occur.
|
||||
func predictListPaths(c *api.Client, path string) []string {
|
||||
secret, err := c.Logical().List(path)
|
||||
func (p *Predict) listPaths(path string) []string {
|
||||
client := p.Client()
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
secret, err := client.Logical().List(path)
|
||||
if err != nil || secret == nil || secret.Data == nil {
|
||||
return nil
|
||||
}
|
||||
@ -180,8 +234,8 @@ func predictListPaths(c *api.Client, path string) []string {
|
||||
return list
|
||||
}
|
||||
|
||||
// predictHasPathArg determines if the args have already accepted a path.
|
||||
func predictHasPathArg(args []string) bool {
|
||||
// hasPathArg determines if the args have already accepted a path.
|
||||
func (p *Predict) hasPathArg(args []string) bool {
|
||||
var nonFlags []string
|
||||
for _, a := range args {
|
||||
if !strings.HasPrefix(a, "-") {
|
||||
|
||||
@ -189,7 +189,10 @@ func TestPredictVaultPaths(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
f := predictVaultPaths(client, tc.includeFiles)
|
||||
p := NewPredict()
|
||||
p.client = client
|
||||
|
||||
f := p.vaultPaths(tc.includeFiles)
|
||||
act := f(tc.args)
|
||||
if !reflect.DeepEqual(act, tc.exp) {
|
||||
t.Errorf("expected %q to be %q", act, tc.exp)
|
||||
@ -199,7 +202,7 @@ func TestPredictVaultPaths(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPredictMounts(t *testing.T) {
|
||||
func TestPredict_Mounts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
@ -233,7 +236,10 @@ func TestPredictMounts(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
act := predictMounts(client, tc.path)
|
||||
p := NewPredict()
|
||||
p.client = client
|
||||
|
||||
act := p.mounts(tc.path)
|
||||
if !reflect.DeepEqual(act, tc.exp) {
|
||||
t.Errorf("expected %q to be %q", act, tc.exp)
|
||||
}
|
||||
@ -242,7 +248,7 @@ func TestPredictMounts(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPredictPaths(t *testing.T) {
|
||||
func TestPredict_Paths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
@ -303,7 +309,10 @@ func TestPredictPaths(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
act := predictPaths(client, tc.path, tc.includeFiles)
|
||||
p := NewPredict()
|
||||
p.client = client
|
||||
|
||||
act := p.paths(tc.path, tc.includeFiles)
|
||||
if !reflect.DeepEqual(act, tc.exp) {
|
||||
t.Errorf("expected %q to be %q", act, tc.exp)
|
||||
}
|
||||
@ -312,7 +321,7 @@ func TestPredictPaths(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPredictListMounts(t *testing.T) {
|
||||
func TestPredict_ListMounts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
@ -345,7 +354,10 @@ func TestPredictListMounts(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
act := predictListMounts(tc.client)
|
||||
p := NewPredict()
|
||||
p.client = client
|
||||
|
||||
act := p.listMounts()
|
||||
if !reflect.DeepEqual(act, tc.exp) {
|
||||
t.Errorf("expected %q to be %q", act, tc.exp)
|
||||
}
|
||||
@ -354,7 +366,7 @@ func TestPredictListMounts(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPredictListPaths(t *testing.T) {
|
||||
func TestPredict_ListPaths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, closer := testVaultServer(t)
|
||||
@ -394,7 +406,10 @@ func TestPredictListPaths(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
act := predictListPaths(tc.client, tc.path)
|
||||
p := NewPredict()
|
||||
p.client = client
|
||||
|
||||
act := p.listPaths(tc.path)
|
||||
if !reflect.DeepEqual(act, tc.exp) {
|
||||
t.Errorf("expected %q to be %q", act, tc.exp)
|
||||
}
|
||||
@ -403,7 +418,7 @@ func TestPredictListPaths(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPredictHasPathArg(t *testing.T) {
|
||||
func TestPredict_HasPathArg(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
@ -443,7 +458,8 @@ func TestPredictHasPathArg(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if act := predictHasPathArg(tc.args); act != tc.exp {
|
||||
p := NewPredict()
|
||||
if act := p.hasPathArg(tc.args); act != tc.exp {
|
||||
t.Errorf("expected %t to be %t", act, tc.exp)
|
||||
}
|
||||
})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user