From 3a1479bc8c0191be21fd429d474462e0cc6e7303 Mon Sep 17 00:00:00 2001 From: Seth Vargo Date: Tue, 29 Aug 2017 00:24:22 -0400 Subject: [PATCH] Make predict it's own struct The previous architecture would create an API client many times, slowing down the CLI exponentially for each new command added. --- command/base.go | 118 ++++++++++++++---------------- command/base_predict.go | 138 ++++++++++++++++++++++++----------- command/base_predict_test.go | 38 +++++++--- 3 files changed, 177 insertions(+), 117 deletions(-) diff --git a/command/base.go b/command/base.go index c1e64c0239..a918dc39fd 100644 --- a/command/base.go +++ b/command/base.go @@ -48,85 +48,75 @@ type BaseCommand struct { tokenHelper TokenHelperFunc - client *api.Client - clientErr error - clientOnce sync.Once + // For testing + client *api.Client } // Client returns the HTTP API client. The client is cached on the command to // save performance on future calls. func (c *BaseCommand) Client() (*api.Client, error) { - c.clientOnce.Do(func() { - // This should never happen in reality and is just for testing. Nothing - // should be setting the underlying client. - if c.client != nil { - return + // Read the test client if present + if c.client != nil { + return c.client, nil + } + + config := api.DefaultConfig() + + if err := config.ReadEnvironment(); err != nil { + return nil, errors.Wrap(err, "failed to read environment") + } + + if c.flagAddress != "" { + config.Address = c.flagAddress + } + + // If we need custom TLS configuration, then set it + if c.flagCACert != "" || c.flagCAPath != "" || c.flagClientCert != "" || + c.flagClientKey != "" || c.flagTLSServerName != "" || c.flagTLSSkipVerify { + t := &api.TLSConfig{ + CACert: c.flagCACert, + CAPath: c.flagCAPath, + ClientCert: c.flagClientCert, + ClientKey: c.flagClientKey, + TLSServerName: c.flagTLSServerName, + Insecure: c.flagTLSSkipVerify, } + config.ConfigureTLS(t) + } - config := api.DefaultConfig() + // Build the client + client, err := api.NewClient(config) + if err != nil { + return nil, errors.Wrap(err, "failed to create client") + } - if err := config.ReadEnvironment(); err != nil { - c.clientErr = errors.Wrap(err, "failed to read environment") - return - } + // Set the wrapping function + client.SetWrappingLookupFunc(c.DefaultWrappingLookupFunc) - if c.flagAddress != "" { - config.Address = c.flagAddress - } + // Get the token if it came in from the environment + token := client.Token() - // If we need custom TLS configuration, then set it - if c.flagCACert != "" || c.flagCAPath != "" || c.flagClientCert != "" || - c.flagClientKey != "" || c.flagTLSServerName != "" || c.flagTLSSkipVerify { - t := &api.TLSConfig{ - CACert: c.flagCACert, - CAPath: c.flagCAPath, - ClientCert: c.flagClientCert, - ClientKey: c.flagClientKey, - TLSServerName: c.flagTLSServerName, - Insecure: c.flagTLSSkipVerify, + // If we don't have a token, check the token helper + if token == "" { + if c.tokenHelper != nil { + // If we have a token, then set that + tokenHelper, err := c.tokenHelper() + if err != nil { + return nil, errors.Wrap(err, "failed to get token helper") } - config.ConfigureTLS(t) - } - - // Build the client - client, err := api.NewClient(config) - if err != nil { - c.clientErr = errors.Wrap(err, "failed to create client") - return - } - - // Set the wrapping function - client.SetWrappingLookupFunc(c.DefaultWrappingLookupFunc) - - // Get the token if it came in from the environment - token := client.Token() - - // If we don't have a token, check the token helper - if token == "" { - if c.tokenHelper != nil { - // If we have a token, then set that - tokenHelper, err := c.tokenHelper() - if err != nil { - c.clientErr = errors.Wrap(err, "failed to get token helper") - return - } - token, err = tokenHelper.Get() - if err != nil { - c.clientErr = errors.Wrap(err, "failed to retrieve from token helper") - return - } + token, err = tokenHelper.Get() + if err != nil { + return nil, errors.Wrap(err, "failed to retrieve from token helper") } } + } - // Set the token - if token != "" { - client.SetToken(token) - } + // Set the token + if token != "" { + client.SetToken(token) + } - c.client = client - }) - - return c.client, c.clientErr + return client, nil } // DefaultWrappingLookupFunc is the default wrapping function based on the diff --git a/command/base_predict.go b/command/base_predict.go index 2347193523..3b59147bb7 100644 --- a/command/base_predict.go +++ b/command/base_predict.go @@ -3,11 +3,30 @@ package command import ( "sort" "strings" + "sync" "github.com/hashicorp/vault/api" "github.com/posener/complete" ) +type Predict struct { + client *api.Client + clientOnce sync.Once +} + +func NewPredict() *Predict { + return &Predict{} +} + +func (p *Predict) Client() *api.Client { + p.clientOnce.Do(func() { + if p.client == nil { // For tests + p.client, _ = api.NewClient(nil) + } + }) + return p.client +} + // defaultPredictVaultMounts is the default list of mounts to return to the // user. This is a best-guess, given we haven't communicated with the Vault // server. If the user has no token or if the token does not have the default @@ -15,48 +34,63 @@ import ( // that returning nothing. var defaultPredictVaultMounts = []string{"cubbyhole/"} +// predictClient is the API client to use for prediction. We create this at the +// beginning once, because completions are generated for each command (and this +// doesn't change), and the only way to configure the predict/autocomplete +// client is via environment variables. Even if the user specifies a flag, we +// can't parse that flag until after the command is submitted. +var predictClient *api.Client +var predictClientOnce sync.Once + +// PredictClient returns the cached API client for the predictor. +func PredictClient() *api.Client { + predictClientOnce.Do(func() { + if predictClient == nil { // For tests + predictClient, _ = api.NewClient(nil) + } + }) + return predictClient +} + // PredictVaultFiles returns a predictor for Vault mounts and paths based on the // configured client for the base command. Unfortunately this happens pre-flag // parsing, so users must rely on environment variables for autocomplete if they // are not using Vault at the default endpoints. func (b *BaseCommand) PredictVaultFiles() complete.Predictor { - client, err := b.Client() - if err != nil { - return nil - } - return PredictVaultFiles(client) + return NewPredict().VaultFiles() } // PredictVaultFolders returns a predictor for "folders". See PredictVaultFiles // for more information and restrictions. func (b *BaseCommand) PredictVaultFolders() complete.Predictor { - client, err := b.Client() - if err != nil { - return nil - } - return PredictVaultFolders(client) + return NewPredict().VaultFolders() } -// PredictVaultFiles returns a predictor for Vault "files". This is a public API -// for consumers, but you probably want BaseCommand.PredictVaultFiles instead. -func PredictVaultFiles(client *api.Client) complete.Predictor { - return predictVaultPaths(client, true) +// VaultFiles returns a predictor for Vault "files". This is a public API for +// consumers, but you probably want BaseCommand.PredictVaultFiles instead. +func (p *Predict) VaultFiles() complete.Predictor { + return p.vaultPaths(true) } -// PredictVaultFolders returns a predictor for Vault "folders". This is a public +// VaultFolders returns a predictor for Vault "folders". This is a public // API for consumers, but you probably want BaseCommand.PredictVaultFolders // instead. -func PredictVaultFolders(client *api.Client) complete.Predictor { - return predictVaultPaths(client, false) +func (p *Predict) VaultFolders() complete.Predictor { + return p.vaultPaths(false) } -// predictVaultPaths parses the CLI options and returns the "best" list of -// possible paths. If there are any errors, this function returns an empty -// result. All errors are suppressed since this is a prediction function. -func predictVaultPaths(client *api.Client, includeFiles bool) complete.PredictFunc { +// vaultPaths parses the CLI options and returns the "best" list of possible +// paths. If there are any errors, this function returns an empty result. All +// errors are suppressed since this is a prediction function. +func (p *Predict) vaultPaths(includeFiles bool) complete.PredictFunc { return func(args complete.Args) []string { // Do not predict more than one paths - if predictHasPathArg(args.All) { + if p.hasPathArg(args.All) { + return nil + } + + client := p.Client() + if client == nil { return nil } @@ -64,9 +98,9 @@ func predictVaultPaths(client *api.Client, includeFiles bool) complete.PredictFu var predictions []string if strings.Contains(path, "/") { - predictions = predictPaths(client, path, includeFiles) + predictions = p.paths(path, includeFiles) } else { - predictions = predictMounts(client, path) + predictions = p.mounts(path) } // Either no results or many results, so return. @@ -87,14 +121,19 @@ func predictVaultPaths(client *api.Client, includeFiles bool) complete.PredictFu // Re-predict with the remaining path args.Last = predictions[0] - return predictVaultPaths(client, includeFiles).Predict(args) + return p.vaultPaths(includeFiles).Predict(args) } } -// predictMounts predicts all mounts which start with the given prefix. These -// are predicted on mount path, not "type". -func predictMounts(client *api.Client, path string) []string { - mounts := predictListMounts(client) +// mounts predicts all mounts which start with the given prefix. These are +// predicted on mount path, not "type". +func (p *Predict) mounts(path string) []string { + client := p.Client() + if client == nil { + return nil + } + + mounts := p.listMounts() var predictions []string for _, m := range mounts { @@ -106,8 +145,13 @@ func predictMounts(client *api.Client, path string) []string { return predictions } -// predictPaths predicts all paths which start with the given path. -func predictPaths(client *api.Client, path string, includeFiles bool) []string { +// paths predicts all paths which start with the given path. +func (p *Predict) paths(path string, includeFiles bool) []string { + client := p.Client() + if client == nil { + return nil + } + // Vault does not support listing based on a sub-key, so we have to back-pedal // to the last "/" and return all paths on that "folder". Then we perform // client-side filtering. @@ -117,7 +161,7 @@ func predictPaths(client *api.Client, path string, includeFiles bool) []string { root = root[:idx+1] } - paths := predictListPaths(client, root) + paths := p.listPaths(root) var predictions []string for _, p := range paths { @@ -140,11 +184,16 @@ func predictPaths(client *api.Client, path string, includeFiles bool) []string { return predictions } -// predictListMounts returns a sorted list of the mount paths for Vault server -// for which the client is configured to communicate with. This function returns -// the default list of mounts if an error occurs. -func predictListMounts(c *api.Client) []string { - mounts, err := c.Sys().ListMounts() +// listMounts returns a sorted list of the mount paths for Vault server for +// which the client is configured to communicate with. This function returns the +// default list of mounts if an error occurs. +func (p *Predict) listMounts() []string { + client := p.Client() + if client == nil { + return nil + } + + mounts, err := client.Sys().ListMounts() if err != nil { return defaultPredictVaultMounts } @@ -157,10 +206,15 @@ func predictListMounts(c *api.Client) []string { return list } -// predictListPaths returns a list of paths (HTTP LIST) for the given path. This +// listPaths returns a list of paths (HTTP LIST) for the given path. This // function returns an empty list of any errors occur. -func predictListPaths(c *api.Client, path string) []string { - secret, err := c.Logical().List(path) +func (p *Predict) listPaths(path string) []string { + client := p.Client() + if client == nil { + return nil + } + + secret, err := client.Logical().List(path) if err != nil || secret == nil || secret.Data == nil { return nil } @@ -180,8 +234,8 @@ func predictListPaths(c *api.Client, path string) []string { return list } -// predictHasPathArg determines if the args have already accepted a path. -func predictHasPathArg(args []string) bool { +// hasPathArg determines if the args have already accepted a path. +func (p *Predict) hasPathArg(args []string) bool { var nonFlags []string for _, a := range args { if !strings.HasPrefix(a, "-") { diff --git a/command/base_predict_test.go b/command/base_predict_test.go index 05af55da7c..3de6e5f0e7 100644 --- a/command/base_predict_test.go +++ b/command/base_predict_test.go @@ -189,7 +189,10 @@ func TestPredictVaultPaths(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - f := predictVaultPaths(client, tc.includeFiles) + p := NewPredict() + p.client = client + + f := p.vaultPaths(tc.includeFiles) act := f(tc.args) if !reflect.DeepEqual(act, tc.exp) { t.Errorf("expected %q to be %q", act, tc.exp) @@ -199,7 +202,7 @@ func TestPredictVaultPaths(t *testing.T) { }) } -func TestPredictMounts(t *testing.T) { +func TestPredict_Mounts(t *testing.T) { t.Parallel() client, closer := testVaultServer(t) @@ -233,7 +236,10 @@ func TestPredictMounts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - act := predictMounts(client, tc.path) + p := NewPredict() + p.client = client + + act := p.mounts(tc.path) if !reflect.DeepEqual(act, tc.exp) { t.Errorf("expected %q to be %q", act, tc.exp) } @@ -242,7 +248,7 @@ func TestPredictMounts(t *testing.T) { }) } -func TestPredictPaths(t *testing.T) { +func TestPredict_Paths(t *testing.T) { t.Parallel() client, closer := testVaultServer(t) @@ -303,7 +309,10 @@ func TestPredictPaths(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - act := predictPaths(client, tc.path, tc.includeFiles) + p := NewPredict() + p.client = client + + act := p.paths(tc.path, tc.includeFiles) if !reflect.DeepEqual(act, tc.exp) { t.Errorf("expected %q to be %q", act, tc.exp) } @@ -312,7 +321,7 @@ func TestPredictPaths(t *testing.T) { }) } -func TestPredictListMounts(t *testing.T) { +func TestPredict_ListMounts(t *testing.T) { t.Parallel() client, closer := testVaultServer(t) @@ -345,7 +354,10 @@ func TestPredictListMounts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - act := predictListMounts(tc.client) + p := NewPredict() + p.client = client + + act := p.listMounts() if !reflect.DeepEqual(act, tc.exp) { t.Errorf("expected %q to be %q", act, tc.exp) } @@ -354,7 +366,7 @@ func TestPredictListMounts(t *testing.T) { }) } -func TestPredictListPaths(t *testing.T) { +func TestPredict_ListPaths(t *testing.T) { t.Parallel() client, closer := testVaultServer(t) @@ -394,7 +406,10 @@ func TestPredictListPaths(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - act := predictListPaths(tc.client, tc.path) + p := NewPredict() + p.client = client + + act := p.listPaths(tc.path) if !reflect.DeepEqual(act, tc.exp) { t.Errorf("expected %q to be %q", act, tc.exp) } @@ -403,7 +418,7 @@ func TestPredictListPaths(t *testing.T) { }) } -func TestPredictHasPathArg(t *testing.T) { +func TestPredict_HasPathArg(t *testing.T) { t.Parallel() cases := []struct { @@ -443,7 +458,8 @@ func TestPredictHasPathArg(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - if act := predictHasPathArg(tc.args); act != tc.exp { + p := NewPredict() + if act := p.hasPathArg(tc.args); act != tc.exp { t.Errorf("expected %t to be %t", act, tc.exp) } })