Make predict it's own struct

The previous architecture would create an API client many times, slowing down the CLI exponentially for each new command added.
This commit is contained in:
Seth Vargo 2017-08-29 00:24:22 -04:00
parent fb81547a3a
commit 3a1479bc8c
No known key found for this signature in database
GPG Key ID: C921994F9C27E0FF
3 changed files with 177 additions and 117 deletions

View File

@ -48,85 +48,75 @@ type BaseCommand struct {
tokenHelper TokenHelperFunc
client *api.Client
clientErr error
clientOnce sync.Once
// For testing
client *api.Client
}
// Client returns the HTTP API client. The client is cached on the command to
// save performance on future calls.
func (c *BaseCommand) Client() (*api.Client, error) {
c.clientOnce.Do(func() {
// This should never happen in reality and is just for testing. Nothing
// should be setting the underlying client.
if c.client != nil {
return
// Read the test client if present
if c.client != nil {
return c.client, nil
}
config := api.DefaultConfig()
if err := config.ReadEnvironment(); err != nil {
return nil, errors.Wrap(err, "failed to read environment")
}
if c.flagAddress != "" {
config.Address = c.flagAddress
}
// If we need custom TLS configuration, then set it
if c.flagCACert != "" || c.flagCAPath != "" || c.flagClientCert != "" ||
c.flagClientKey != "" || c.flagTLSServerName != "" || c.flagTLSSkipVerify {
t := &api.TLSConfig{
CACert: c.flagCACert,
CAPath: c.flagCAPath,
ClientCert: c.flagClientCert,
ClientKey: c.flagClientKey,
TLSServerName: c.flagTLSServerName,
Insecure: c.flagTLSSkipVerify,
}
config.ConfigureTLS(t)
}
config := api.DefaultConfig()
// Build the client
client, err := api.NewClient(config)
if err != nil {
return nil, errors.Wrap(err, "failed to create client")
}
if err := config.ReadEnvironment(); err != nil {
c.clientErr = errors.Wrap(err, "failed to read environment")
return
}
// Set the wrapping function
client.SetWrappingLookupFunc(c.DefaultWrappingLookupFunc)
if c.flagAddress != "" {
config.Address = c.flagAddress
}
// Get the token if it came in from the environment
token := client.Token()
// If we need custom TLS configuration, then set it
if c.flagCACert != "" || c.flagCAPath != "" || c.flagClientCert != "" ||
c.flagClientKey != "" || c.flagTLSServerName != "" || c.flagTLSSkipVerify {
t := &api.TLSConfig{
CACert: c.flagCACert,
CAPath: c.flagCAPath,
ClientCert: c.flagClientCert,
ClientKey: c.flagClientKey,
TLSServerName: c.flagTLSServerName,
Insecure: c.flagTLSSkipVerify,
// If we don't have a token, check the token helper
if token == "" {
if c.tokenHelper != nil {
// If we have a token, then set that
tokenHelper, err := c.tokenHelper()
if err != nil {
return nil, errors.Wrap(err, "failed to get token helper")
}
config.ConfigureTLS(t)
}
// Build the client
client, err := api.NewClient(config)
if err != nil {
c.clientErr = errors.Wrap(err, "failed to create client")
return
}
// Set the wrapping function
client.SetWrappingLookupFunc(c.DefaultWrappingLookupFunc)
// Get the token if it came in from the environment
token := client.Token()
// If we don't have a token, check the token helper
if token == "" {
if c.tokenHelper != nil {
// If we have a token, then set that
tokenHelper, err := c.tokenHelper()
if err != nil {
c.clientErr = errors.Wrap(err, "failed to get token helper")
return
}
token, err = tokenHelper.Get()
if err != nil {
c.clientErr = errors.Wrap(err, "failed to retrieve from token helper")
return
}
token, err = tokenHelper.Get()
if err != nil {
return nil, errors.Wrap(err, "failed to retrieve from token helper")
}
}
}
// Set the token
if token != "" {
client.SetToken(token)
}
// Set the token
if token != "" {
client.SetToken(token)
}
c.client = client
})
return c.client, c.clientErr
return client, nil
}
// DefaultWrappingLookupFunc is the default wrapping function based on the

View File

@ -3,11 +3,30 @@ package command
import (
"sort"
"strings"
"sync"
"github.com/hashicorp/vault/api"
"github.com/posener/complete"
)
type Predict struct {
client *api.Client
clientOnce sync.Once
}
func NewPredict() *Predict {
return &Predict{}
}
func (p *Predict) Client() *api.Client {
p.clientOnce.Do(func() {
if p.client == nil { // For tests
p.client, _ = api.NewClient(nil)
}
})
return p.client
}
// defaultPredictVaultMounts is the default list of mounts to return to the
// user. This is a best-guess, given we haven't communicated with the Vault
// server. If the user has no token or if the token does not have the default
@ -15,48 +34,63 @@ import (
// that returning nothing.
var defaultPredictVaultMounts = []string{"cubbyhole/"}
// predictClient is the API client to use for prediction. We create this at the
// beginning once, because completions are generated for each command (and this
// doesn't change), and the only way to configure the predict/autocomplete
// client is via environment variables. Even if the user specifies a flag, we
// can't parse that flag until after the command is submitted.
var predictClient *api.Client
var predictClientOnce sync.Once
// PredictClient returns the cached API client for the predictor.
func PredictClient() *api.Client {
predictClientOnce.Do(func() {
if predictClient == nil { // For tests
predictClient, _ = api.NewClient(nil)
}
})
return predictClient
}
// PredictVaultFiles returns a predictor for Vault mounts and paths based on the
// configured client for the base command. Unfortunately this happens pre-flag
// parsing, so users must rely on environment variables for autocomplete if they
// are not using Vault at the default endpoints.
func (b *BaseCommand) PredictVaultFiles() complete.Predictor {
client, err := b.Client()
if err != nil {
return nil
}
return PredictVaultFiles(client)
return NewPredict().VaultFiles()
}
// PredictVaultFolders returns a predictor for "folders". See PredictVaultFiles
// for more information and restrictions.
func (b *BaseCommand) PredictVaultFolders() complete.Predictor {
client, err := b.Client()
if err != nil {
return nil
}
return PredictVaultFolders(client)
return NewPredict().VaultFolders()
}
// PredictVaultFiles returns a predictor for Vault "files". This is a public API
// for consumers, but you probably want BaseCommand.PredictVaultFiles instead.
func PredictVaultFiles(client *api.Client) complete.Predictor {
return predictVaultPaths(client, true)
// VaultFiles returns a predictor for Vault "files". This is a public API for
// consumers, but you probably want BaseCommand.PredictVaultFiles instead.
func (p *Predict) VaultFiles() complete.Predictor {
return p.vaultPaths(true)
}
// PredictVaultFolders returns a predictor for Vault "folders". This is a public
// VaultFolders returns a predictor for Vault "folders". This is a public
// API for consumers, but you probably want BaseCommand.PredictVaultFolders
// instead.
func PredictVaultFolders(client *api.Client) complete.Predictor {
return predictVaultPaths(client, false)
func (p *Predict) VaultFolders() complete.Predictor {
return p.vaultPaths(false)
}
// predictVaultPaths parses the CLI options and returns the "best" list of
// possible paths. If there are any errors, this function returns an empty
// result. All errors are suppressed since this is a prediction function.
func predictVaultPaths(client *api.Client, includeFiles bool) complete.PredictFunc {
// vaultPaths parses the CLI options and returns the "best" list of possible
// paths. If there are any errors, this function returns an empty result. All
// errors are suppressed since this is a prediction function.
func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc {
return func(args complete.Args) []string {
// Do not predict more than one paths
if predictHasPathArg(args.All) {
if p.hasPathArg(args.All) {
return nil
}
client := p.Client()
if client == nil {
return nil
}
@ -64,9 +98,9 @@ func predictVaultPaths(client *api.Client, includeFiles bool) complete.PredictFu
var predictions []string
if strings.Contains(path, "/") {
predictions = predictPaths(client, path, includeFiles)
predictions = p.paths(path, includeFiles)
} else {
predictions = predictMounts(client, path)
predictions = p.mounts(path)
}
// Either no results or many results, so return.
@ -87,14 +121,19 @@ func predictVaultPaths(client *api.Client, includeFiles bool) complete.PredictFu
// Re-predict with the remaining path
args.Last = predictions[0]
return predictVaultPaths(client, includeFiles).Predict(args)
return p.vaultPaths(includeFiles).Predict(args)
}
}
// predictMounts predicts all mounts which start with the given prefix. These
// are predicted on mount path, not "type".
func predictMounts(client *api.Client, path string) []string {
mounts := predictListMounts(client)
// mounts predicts all mounts which start with the given prefix. These are
// predicted on mount path, not "type".
func (p *Predict) mounts(path string) []string {
client := p.Client()
if client == nil {
return nil
}
mounts := p.listMounts()
var predictions []string
for _, m := range mounts {
@ -106,8 +145,13 @@ func predictMounts(client *api.Client, path string) []string {
return predictions
}
// predictPaths predicts all paths which start with the given path.
func predictPaths(client *api.Client, path string, includeFiles bool) []string {
// paths predicts all paths which start with the given path.
func (p *Predict) paths(path string, includeFiles bool) []string {
client := p.Client()
if client == nil {
return nil
}
// Vault does not support listing based on a sub-key, so we have to back-pedal
// to the last "/" and return all paths on that "folder". Then we perform
// client-side filtering.
@ -117,7 +161,7 @@ func predictPaths(client *api.Client, path string, includeFiles bool) []string {
root = root[:idx+1]
}
paths := predictListPaths(client, root)
paths := p.listPaths(root)
var predictions []string
for _, p := range paths {
@ -140,11 +184,16 @@ func predictPaths(client *api.Client, path string, includeFiles bool) []string {
return predictions
}
// predictListMounts returns a sorted list of the mount paths for Vault server
// for which the client is configured to communicate with. This function returns
// the default list of mounts if an error occurs.
func predictListMounts(c *api.Client) []string {
mounts, err := c.Sys().ListMounts()
// listMounts returns a sorted list of the mount paths for Vault server for
// which the client is configured to communicate with. This function returns the
// default list of mounts if an error occurs.
func (p *Predict) listMounts() []string {
client := p.Client()
if client == nil {
return nil
}
mounts, err := client.Sys().ListMounts()
if err != nil {
return defaultPredictVaultMounts
}
@ -157,10 +206,15 @@ func predictListMounts(c *api.Client) []string {
return list
}
// predictListPaths returns a list of paths (HTTP LIST) for the given path. This
// listPaths returns a list of paths (HTTP LIST) for the given path. This
// function returns an empty list of any errors occur.
func predictListPaths(c *api.Client, path string) []string {
secret, err := c.Logical().List(path)
func (p *Predict) listPaths(path string) []string {
client := p.Client()
if client == nil {
return nil
}
secret, err := client.Logical().List(path)
if err != nil || secret == nil || secret.Data == nil {
return nil
}
@ -180,8 +234,8 @@ func predictListPaths(c *api.Client, path string) []string {
return list
}
// predictHasPathArg determines if the args have already accepted a path.
func predictHasPathArg(args []string) bool {
// hasPathArg determines if the args have already accepted a path.
func (p *Predict) hasPathArg(args []string) bool {
var nonFlags []string
for _, a := range args {
if !strings.HasPrefix(a, "-") {

View File

@ -189,7 +189,10 @@ func TestPredictVaultPaths(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
f := predictVaultPaths(client, tc.includeFiles)
p := NewPredict()
p.client = client
f := p.vaultPaths(tc.includeFiles)
act := f(tc.args)
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp)
@ -199,7 +202,7 @@ func TestPredictVaultPaths(t *testing.T) {
})
}
func TestPredictMounts(t *testing.T) {
func TestPredict_Mounts(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
@ -233,7 +236,10 @@ func TestPredictMounts(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
act := predictMounts(client, tc.path)
p := NewPredict()
p.client = client
act := p.mounts(tc.path)
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp)
}
@ -242,7 +248,7 @@ func TestPredictMounts(t *testing.T) {
})
}
func TestPredictPaths(t *testing.T) {
func TestPredict_Paths(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
@ -303,7 +309,10 @@ func TestPredictPaths(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
act := predictPaths(client, tc.path, tc.includeFiles)
p := NewPredict()
p.client = client
act := p.paths(tc.path, tc.includeFiles)
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp)
}
@ -312,7 +321,7 @@ func TestPredictPaths(t *testing.T) {
})
}
func TestPredictListMounts(t *testing.T) {
func TestPredict_ListMounts(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
@ -345,7 +354,10 @@ func TestPredictListMounts(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
act := predictListMounts(tc.client)
p := NewPredict()
p.client = client
act := p.listMounts()
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp)
}
@ -354,7 +366,7 @@ func TestPredictListMounts(t *testing.T) {
})
}
func TestPredictListPaths(t *testing.T) {
func TestPredict_ListPaths(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
@ -394,7 +406,10 @@ func TestPredictListPaths(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
act := predictListPaths(tc.client, tc.path)
p := NewPredict()
p.client = client
act := p.listPaths(tc.path)
if !reflect.DeepEqual(act, tc.exp) {
t.Errorf("expected %q to be %q", act, tc.exp)
}
@ -403,7 +418,7 @@ func TestPredictListPaths(t *testing.T) {
})
}
func TestPredictHasPathArg(t *testing.T) {
func TestPredict_HasPathArg(t *testing.T) {
t.Parallel()
cases := []struct {
@ -443,7 +458,8 @@ func TestPredictHasPathArg(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if act := predictHasPathArg(tc.args); act != tc.exp {
p := NewPredict()
if act := p.hasPathArg(tc.args); act != tc.exp {
t.Errorf("expected %t to be %t", act, tc.exp)
}
})