Merge master

This commit is contained in:
Yoko Hyakuna 2018-01-30 09:57:30 -08:00
commit 0b45ad6a15
1877 changed files with 165384 additions and 60790 deletions

View File

@ -1,9 +1,105 @@
## 0.9.2 (Unreleased)
## 0.9.3 (January 28th, 2018)
A regression from a feature merge disabled the Nomad secrets backend in 0.9.2.
This release re-enables the Nomad secrets backend; it is otherwise identical to
0.9.2.
## 0.9.2 (January 26th, 2018)
SECURITY:
* Okta Auth Backend: While the Okta auth backend was successfully verifying
usernames and passwords, it was not checking the returned state of the
account, so accounts that had been marked locked out could still be used to
log in. Only accounts in SUCCESS or PASSWORD_WARN states are now allowed.
* Periodic Tokens: A regression in 0.9.1 meant that periodic tokens created by
the AppRole, AWS, and Cert auth backends would expire when the max TTL for
the backend/mount/system was hit instead of their stated behavior of living
as long as they are renewed. This is now fixed; existing tokens do not have
to be reissued as this was purely a regression in the renewal logic.
* Seal Wrapping: During certain replication states values written marked for
seal wrapping may not be wrapped on the secondaries. This has been fixed,
and existing values will be wrapped on next read or write. This does not
affect the barrier keys.
DEPRECATIONS/CHANGES:
* `sys/health` DR Secondary Reporting: The `replication_dr_secondary` bool
returned by `sys/health` could be misleading since it would be `false` both
when a cluster was not a DR secondary but also when the node is a standby in
the cluster and has not yet fully received state from the active node. This
could cause health checks on LBs to decide that the node was acceptable for
traffic even though DR secondaries cannot handle normal Vault traffic. (In
other words, the bool could only convey "yes" or "no" but not "not sure
yet".) This has been replaced by `replication_dr_mode` and
`replication_perf_mode` which are string values that convey the current
state of the node; a value of `disabled` indicates that replication is
disabled or the state is still being discovered. As a result, an LB check
can positively verify that the node is both not `disabled` and is not a DR
secondary, and avoid sending traffic to it if either is true.
* PKI Secret Backend Roles parameter types: For `ou` and `organization`
in role definitions in the PKI secret backend, input can now be a
comma-separated string or an array of strings. Reading a role will
now return arrays for these parameters.
* Plugin API Changes: The plugin API has been updated to utilize golang's
context.Context package. Many function signatures now accept a context
object as the first parameter. Existing plugins will need to pull in the
latest Vault code and update their function signatures to begin using
context and the new gRPC transport.
FEATURES:
* **gRPC Backend Plugins**: Backend plugins now use gRPC for transport,
allowing them to be written in other languages.
* **Brand New CLI**: Vault has a brand new CLI interface that is significantly
streamlined, supports autocomplete, and is almost entirely backwards
compatible.
* **UI: PKI Secret Backend (Enterprise)**: Configure PKI secret backends,
create and browse roles and certificates, and issue and sign certificates via
the listed roles.
IMPROVEMENTS:
* auth/aws: Handle IAM headers produced by clients that formulate numbers as
ints rather than strings [GH-3763]
* auth/okta: Support JSON lists when specifying groups and policies [GH-3801]
* autoseal/hsm: Attempt reconnecting to the HSM on certain kinds of issues,
including HA scenarios for some Gemalto HSMs.
(Enterprise)
* cli: Output password prompts to stderr to make it easier to pipe an output
token to another command [GH-3782]
* core: Report replication status in `sys/health` [GH-3810]
* physical/s3: Allow using paths with S3 for non-AWS deployments [GH-3730]
* physical/s3: Add ability to disable SSL for non-AWS deployments [GH-3730]
* plugins: Args for plugins can now be specified separately from the command,
allowing the same output format and input format for plugin information
[GH-3778]
* secret/pki: `ou` and `organization` can now be specified as a
comma-separated string or an array of strings [GH-3804]
* plugins: Plugins will fall back to using netrpc as the communication protocol
on older versions of Vault [GH-3833]
BUG FIXES:
* auth/(approle,aws,cert): Fix behavior where periodic tokens generated by
these backends could not have their TTL renewed beyond the system/mount max
TTL value [GH-3803]
* auth/aws: Fix error returned if `bound_iam_principal_arn` was given to an
existing role update [GH-3843]
* core/sealwrap: Speed improvements and bug fixes (Enterprise)
* identity: Delete group alias when an external group is deleted [GH-3773]
* legacymfa/duo: Fix intermittent panic when Duo could not be reached
[GH-2030]
* secret/database: Fix a location where a lock could potentially not be
released, leading to deadlock [GH-3774]
* secret/(all databases) Fix behavior where if a max TTL was specified but no
default TTL was specified the system/mount default TTL would be used but not
be capped by the local max TTL [GH-3814]
* secret/database: Fix an issue where plugins were not closed properly if they
failed to initialize [GH-3768]
* ui: mounting a secret backend will now properly set `max_lease_ttl` and
`default_lease_ttl` when specified - previously both fields set
`default_lease_ttl`.
## 0.9.1 (December 21st, 2017)
@ -291,14 +387,14 @@ IMPROVEMENTS:
(PID) in a file [GH-3321]
* mfa (Enterprise): Add the ability to use identity metadata in username format
* mfa/okta (Enterprise): Add support for configuring base_url for API calls
* secret/pki: `sign-intermediate` will now allow specifying a `ttl` value
* secret/pki: `sign-intermediate` will now allow specifying a `ttl` value
longer than the signing CA certificate's NotAfter value. [GH-3325]
* sys/raw: Raw storage access is now disabled by default [GH-3329]
BUG FIXES:
* auth/okta: Fix regression that removed the ability to set base_url [GH-3313]
* core: Fix panic while loading leases at startup on ARM processors
* core: Fix panic while loading leases at startup on ARM processors
[GH-3314]
* secret/pki: Fix `sign-self-issued` encoding the wrong subject public key
[GH-3325]
@ -348,7 +444,7 @@ IMPROVEMENTS:
* auth/okta: Compare groups case-insensitively since Okta is only
case-preserving [GH-3240]
* auth/okta: Standardize Okta configuration APIs across backends [GH-3245]
* cli: Add subcommand autocompletion that can be enabled with
* cli: Add subcommand autocompletion that can be enabled with
`vault -autocomplete-install` [GH-3223]
* cli: Add ability to handle wrapped responses when using `vault auth`. What
is output depends on the other given flags; see the help output for that

View File

@ -31,7 +31,12 @@ dev-dynamic: prep
# test runs the unit tests and vets the code
test: prep
CGO_ENABLED=0 VAULT_TOKEN= VAULT_ACC= go test -tags='$(BUILD_TAGS)' $(TEST) $(TESTARGS) -timeout=20m -parallel=4
@CGO_ENABLED=0 \
VAULT_ADDR= \
VAULT_TOKEN= \
VAULT_DEV_ROOT_TOKEN_ID= \
VAULT_ACC= \
go test -tags='$(BUILD_TAGS)' $(TEST) $(TESTARGS) -timeout=20m -parallel=20
testcompile: prep
@for pkg in $(TEST) ; do \
@ -48,7 +53,12 @@ testacc: prep
# testrace runs the race checker
testrace: prep
CGO_ENABLED=1 VAULT_TOKEN= VAULT_ACC= go test -tags='$(BUILD_TAGS)' -race $(TEST) $(TESTARGS) -timeout=45m -parallel=4
@CGO_ENABLED=1 \
VAULT_ADDR= \
VAULT_TOKEN= \
VAULT_DEV_ROOT_TOKEN_ID= \
VAULT_ACC= \
go test -tags='$(BUILD_TAGS)' -race $(TEST) $(TESTARGS) -timeout=45m -parallel=20
cover:
./scripts/coverage.sh --html
@ -85,7 +95,8 @@ proto:
protoc -I physical physical/types.proto --go_out=plugins=grpc:physical
protoc -I helper/identity -I ../../.. helper/identity/types.proto --go_out=plugins=grpc:helper/identity
protoc builtin/logical/database/dbplugin/*.proto --go_out=plugins=grpc:.
sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go
protoc logical/plugin/pb/*.proto --go_out=plugins=grpc:.
sed -i -e 's/Idp/IDP/' -e 's/Url/URL/' -e 's/Id/ID/' -e 's/EntityId/EntityID/' -e 's/Api/API/' -e 's/Qr/QR/' -e 's/protobuf:"/sentinel:"" protobuf:"/' helper/identity/types.pb.go helper/storagepacker/types.pb.go logical/plugin/pb/backend.pb.go
sed -i -e 's/Iv/IV/' -e 's/Hmac/HMAC/' physical/types.pb.go
fmtcheck:

View File

@ -102,9 +102,9 @@ $ make test TEST=./vault
### Acceptance Tests
Vault has comprehensive [acceptance tests](https://en.wikipedia.org/wiki/Acceptance_testing)
covering most of the features of the secret and auth backends.
covering most of the features of the secret and auth methods.
If you're working on a feature of a secret or auth backend and want to
If you're working on a feature of a secret or auth method and want to
verify it is functioning (and also hasn't broken anything else), we recommend
running the acceptance tests.

View File

@ -1,611 +0,0 @@
FORMAT: 1A
# vault
The Vault API gives you full access to the Vault project.
If you're browsing this API specifiction in GitHub or in raw
format, please excuse some of the odd formatting. This document
is in api-blueprint format that is read by viewers such as
Apiary.
## Sealed vs. Unsealed
Whenever an individual Vault server is started, it is started
in the _sealed_ state. In this state, it knows where its data
is located, but the data is encrypted and Vault doesn't have the
encryption keys to access it. Before Vault can operate, it must
be _unsealed_.
**Note:** Sealing/unsealing has no relationship to _authentication_
which is separate and still required once the Vault is unsealed.
Instead of being sealed with a single key, we utilize
[Shamir's Secret Sharing](http://en.wikipedia.org/wiki/Shamir%27s_Secret_Sharing)
to shard a key into _n_ parts such that _t_ parts are required
to reconstruct the original key, where `t <= n`. This means that
Vault itself doesn't know the original key, and no single person
has the original key (unless `n = 1`, or `t` parts are given to
a single person).
Unsealing is done via an unauthenticated
[unseal API](#reference/seal/unseal/unseal). This API takes a single
master shard and progresses the unsealing process. Once all shards
are given, the Vault is either unsealed or resets the unsealing
process if the key was invalid.
The entire seal/unseal state is server-wide. This allows multiple
distinct operators to use the unseal API (or more likely the
`vault unseal` command) from separate computers/networks and never
have to transmit their key in order to unseal the vault in a
distributed fashion.
## Transport
The API is expected to be accessed over a TLS connection at
all times, with a valid certificate that is verified by a well
behaved client.
## Authentication
Once the Vault is unsealed, every other operation requires
authentication. There are multiple methods for authentication
that can be enabled (see
[authentication](#reference/authentication)).
Authentication is done with the login endpoint. The login endpoint
returns an access token that is set as the `X-Vault-Token` header.
## Help
To retrieve the help for any API within Vault, including mounted
backends, credential providers, etc. then append `?help=1` to any
URL. If you have valid permission to access the path, then the help text
will be returned with the following structure:
{
"help": "help text"
}
## Error Response
A common JSON structure is always returned to return errors:
{
"errors": [
"message",
"another message"
]
}
This structure will be sent down for any non-20x HTTP status.
## HTTP Status Codes
The following HTTP status codes are used throughout the API.
- `200` - Success with data.
- `204` - Success, no data returned.
- `400` - Invalid request, missing or invalid data.
- `403` - Forbidden, your authentication details are either
incorrect or you don't have access to this feature.
- `404` - Invalid path. This can both mean that the path truly
doesn't exist or that you don't have permission to view a
specific path. We use 404 in some cases to avoid state leakage.
- `429` - Rate limit exceeded. Try again after waiting some period
of time.
- `500` - Internal server error. An internal error has occurred,
try again later. If the error persists, report a bug.
- `503` - Vault is down for maintenance or is currently sealed.
Try again later.
# Group Initialization
## Initialization [/sys/init]
### Initialization Status [GET]
Returns the status of whether the vault is initialized or not. The
vault doesn't have to be unsealed for this operation.
+ Response 200 (application/json)
{
"initialized": true
}
### Initialize [POST]
Initialize the vault. This is an unauthenticated request to initially
setup a new vault. Although this is unauthenticated, it is still safe:
data cannot be in vault prior to initialization, and any future
authentication will fail if you didn't initialize it yourself.
Additionally, once initialized, a vault cannot be reinitialized.
This API is the only time Vault will ever be aware of your keys, and
the only time the keys will ever be returned in one unit. Care should
be taken to ensure that the output of this request is never logged,
and that the keys are properly distributed.
The response also contains the initial root token that can be used
as authentication in order to initially configure Vault once it is
unsealed. Just as with the unseal keys, this is the only time Vault is
ever aware of this token.
+ Request (application/json)
{
"secret_shares": 5,
"secret_threshold": 3,
}
+ Response 200 (application/json)
{
"keys": ["one", "two", "three"],
"root_token": "foo"
}
# Group Seal/Unseal
## Seal Status [/sys/seal-status]
### Seal Status [GET]
Returns the status of whether the vault is currently
sealed or not, as well as the progress of unsealing.
The response has the following attributes:
- sealed (boolean) - If true, the vault is sealed. Otherwise,
it is unsealed.
- t (int) - The "t" value for the master key, or the number
of shards needed total to unseal the vault.
- n (int) - The "n" value for the master key, or the total
number of shards of the key distributed.
- progress (int) - The number of master key shards that have
been entered so far towards unsealing the vault.
+ Response 200 (application/json)
{
"sealed": true,
"t": 3,
"n": 5,
"progress": 1
}
## Seal [/sys/seal]
### Seal [PUT]
Seal the vault.
Sealing the vault locks Vault from any future operations on any
secrets or system configuration until the vault is once again
unsealed. Internally, sealing throws away the keys to access the
encrypted vault data, so Vault is unable to access the data without
unsealing to get the encryption keys.
+ Response 204
## Unseal [/sys/unseal]
### Unseal [PUT]
Unseal the vault.
Unseal the vault by entering a portion of the master key. The
response object will tell you if the unseal is complete or
only partial.
If the vault is already unsealed, this does nothing. It is
not an error, the return value just says the vault is unsealed.
Due to the architecture of Vault, we cannot validate whether
any portion of the unseal key given is valid until all keys
are inputted, therefore unsealing an already unsealed vault
is still a success even if the input key is invalid.
+ Request (application/json)
{
"key": "value"
}
+ Response 200 (application/json)
{
"sealed": true,
"t": 3,
"n": 5,
"progress": 1
}
# Group Authentication
## List Auth Methods [/sys/auth]
### List all auth methods [GET]
Lists all available authentication methods.
This returns the name of the authentication method as well as
a human-friendly long-form help text for the method that can be
shown to the user as documentation.
+ Response 200 (application/json)
{
"token": {
"type": "token",
"description": "Token authentication"
},
"oauth": {
"type": "oauth",
"description": "OAuth authentication"
}
}
## Single Auth Method [/sys/auth/{id}]
+ Parameters
+ id (required, string) ... The ID of the auth method.
### Enable an auth method [PUT]
Enables an authentication method.
The body of the request depends on the authentication method
being used. Please reference the documentation for the specific
authentication method you're enabling in order to determine what
parameters you must give it.
If an authentication method is already enabled, then this can be
used to change the configuration, including even the type of
the configuration.
+ Request (application/json)
{
"type": "type",
"key": "value",
"key2": "value2"
}
+ Response 204
### Disable an auth method [DELETE]
Disables an authentication method. Previously authenticated sessions
are immediately invalidated.
+ Response 204
# Group Policies
Policies are named permission sets that identities returned by
credential stores are bound to. This separates _authentication_
from _authorization_.
## Policies [/sys/policy]
### List all Policies [GET]
List all the policies.
+ Response 200 (application/json)
{
"policies": ["root"]
}
## Single Policy [/sys/policy/{id}]
+ Parameters
+ id (required, string) ... The name of the policy
### Upsert [PUT]
Create or update a policy with the given ID.
+ Request (application/json)
{
"rules": "HCL"
}
+ Response 204
### Delete [DELETE]
Delete a policy with the given ID. Any identities bound to this
policy will immediately become "deny all" despite already being
authenticated.
+ Response 204
# Group Mounts
Logical backends are mounted at _mount points_, similar to
filesystems. This allows you to mount the "aws" logical backend
at the "aws-us-east" path, so all access is at `/aws-us-east/keys/foo`
for example. This enables multiple logical backends to be enabled.
## Mounts [/sys/mounts]
### List all mounts [GET]
Lists all the active mount points.
+ Response 200 (application/json)
{
"aws": {
"type": "aws",
"description": "AWS"
},
"pg": {
"type": "postgresql",
"description": "PostgreSQL dynamic users"
}
}
## Single Mount [/sys/mounts/{path}]
### New Mount [POST]
Mount a logical backend to a new path.
Configuration for this new backend is done via the normal
read/write mechanism once it is mounted.
+ Request (application/json)
{
"type": "aws",
"description": "EU AWS tokens"
}
+ Response 204
### Unmount [DELETE]
Unmount a mount point.
+ Response 204
## Remount [/sys/remount]
### Remount [POST]
Move an already-mounted backend to a new path.
+ Request (application/json)
{
"from": "aws",
"to": "aws-east"
}
+ Response 204
# Group Audit Backends
Audit backends are responsible for shuttling the audit logs that
Vault generates to a durable system for future querying. By default,
audit logs are not stored anywhere.
## Audit Backends [/sys/audit]
### List Enabled Audit Backends [GET]
List all the enabled audit backends
+ Response 200 (application/json)
{
"file": {
"type": "file",
"description": "Send audit logs to a file",
"options": {}
}
}
## Single Audit Backend [/sys/audit/{path}]
+ Parameters
+ path (required, string) ... The path where the audit backend is mounted
### Enable [PUT]
Enable an audit backend.
+ Request (application/json)
{
"type": "file",
"description": "send to a file",
"options": {
"path": "/var/log/vault.audit.log"
}
}
+ Response 204
### Disable [DELETE]
Disable an audit backend.
+ Request (application/json)
+ Response 204
# Group Secrets
## Generic [/{mount}/{path}]
This group documents the general format of reading and writing
to Vault. The exact structure of the keyspace is defined by the
logical backends in use, so documentation related to
a specific backend should be referenced for details on what keys
and routes are expected.
The path for examples are `/prefix/path`, but in practice
these will be defined by the backends that are mounted. For
example, reading an AWS key might be at the `/aws/root` path.
These paths are defined by the logical backends.
+ Parameters
+ mount (required, string) ... The mount point for the
logical backend. Example: `aws`.
+ path (optional, string) ... The path within the backend
to read or write data.
### Read [GET]
Read data from vault.
The data read from the vault can either be a secret or
arbitrary configuration data. The type of data returned
depends on the path, and is defined by the logical backend.
If the return value is a secret, then the return structure
is a mixture of arbitrary key/value along with the following
fields which are guaranteed to exist:
- `lease_id` (string) - A unique ID used for renewal and
revocation.
- `renewable` (bool) - If true, then this key can be renewed.
If a key can't be renewed, then a new key must be requested
after the lease duration period.
- `lease_duration` (int) - The time in seconds that a secret is
valid for before it must be renewed.
- `lease_duration_max` (int) - The maximum amount of time in
seconds that a secret is valid for. This will always be
greater than or equal to `lease_duration`. The difference
between this and `lease_duration` is an overlap window
where multiple keys may be valid.
If the return value is not a secret, then the return structure
is an arbitrary JSON object.
+ Response 200 (application/json)
{
"lease_id": "UUID",
"lease_duration": 3600,
"key": "value"
}
### Write [PUT]
Write data to vault.
The behavior and arguments to the write are defined by
the logical backend.
+ Request (application/json)
{
"key": "value"
}
+ Response 204
# Group Lease Management
## Renew Key [/sys/renew/{id}]
+ Parameters
+ id (required, string) ... The `lease_id` of the secret
to renew.
### Renew [PUT]
+ Response 200 (application/json)
{
"lease_id": "...",
"lease_duration": 3600,
"access_key": "foo",
"secret_key": "bar"
}
## Revoke Key [/sys/revoke/{id}]
+ Parameters
+ id (required, string) ... The `lease_id` of the secret
to revoke.
### Revoke [PUT]
+ Response 204
# Group Backend: AWS
## Root Key [/aws/root]
### Set the Key [PUT]
Set the root key that the logical backend will use to create
new secrets, IAM policies, etc.
+ Request (application/json)
{
"access_key": "key",
"secret_key": "key",
"region": "us-east-1"
}
+ Response 204
## Policies [/aws/policies]
### List Policies [GET]
List all the policies that can be used to create keys.
+ Response 200 (application/json)
[{
"name": "root",
"description": "Root access"
}, {
"name": "web-deploy",
"description": "Enough permissions to deploy the web app."
}]
## Single Policy [/aws/policies/{name}]
+ Parameters
+ name (required, string) ... Name of the policy.
### Read [GET]
Read a policy.
+ Response 200 (application/json)
{
"policy": "base64-encoded policy"
}
### Upsert [PUT]
Create or update a policy.
+ Request (application/json)
{
"policy": "base64-encoded policy"
}
+ Response 204
### Delete [DELETE]
Delete the policy with the given name.
+ Response 204
## Generate Access Keys [/aws/keys/{policy}]
### Create [GET]
This generates a new keypair for the given policy.
+ Parameters
+ policy (required, string) ... The policy under which to create
the key pair.
+ Response 200 (application/json)
{
"lease_id": "...",
"lease_duration": 3600,
"access_key": "foo",
"secret_key": "bar"
}

View File

@ -1,59 +1,131 @@
package api_test
import (
"context"
"database/sql"
"encoding/base64"
"fmt"
"net"
"net/http"
"testing"
"time"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/audit"
"github.com/hashicorp/vault/builtin/logical/database"
"github.com/hashicorp/vault/builtin/logical/pki"
"github.com/hashicorp/vault/builtin/logical/transit"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
auditFile "github.com/hashicorp/vault/builtin/audit/file"
credUserpass "github.com/hashicorp/vault/builtin/credential/userpass"
vaulthttp "github.com/hashicorp/vault/http"
logxi "github.com/mgutz/logxi/v1"
dockertest "gopkg.in/ory-am/dockertest.v3"
)
var testVaultServerDefaultBackends = map[string]logical.Factory{
"transit": transit.Factory,
"pki": pki.Factory,
}
// testVaultServer creates a test vault cluster and returns a configured API
// client and closer function.
func testVaultServer(t testing.TB) (*api.Client, func()) {
return testVaultServerBackends(t, testVaultServerDefaultBackends)
t.Helper()
client, _, closer := testVaultServerUnseal(t)
return client, closer
}
func testVaultServerBackends(t testing.TB, backends map[string]logical.Factory) (*api.Client, func()) {
coreConfig := &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: logxi.NullLog,
LogicalBackends: backends,
}
// testVaultServerUnseal creates a test vault cluster and returns a configured
// API client, list of unseal keys (as strings), and a closer function.
func testVaultServerUnseal(t testing.TB) (*api.Client, []string, func()) {
t.Helper()
return testVaultServerCoreConfig(t, &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: logxi.NullLog,
CredentialBackends: map[string]logical.Factory{
"userpass": credUserpass.Factory,
},
AuditBackends: map[string]audit.Factory{
"file": auditFile.Factory,
},
LogicalBackends: map[string]logical.Factory{
"database": database.Factory,
"generic-leased": vault.LeasedPassthroughBackendFactory,
"pki": pki.Factory,
"transit": transit.Factory,
},
})
}
// testVaultServerCoreConfig creates a new vault cluster with the given core
// configuration. This is a lower-level test helper.
func testVaultServerCoreConfig(t testing.TB, coreConfig *vault.CoreConfig) (*api.Client, []string, func()) {
t.Helper()
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
// make it easy to get access to the active
// Make it easy to get access to the active
core := cluster.Cores[0].Core
vault.TestWaitActive(t, core)
// Get the client already setup for us!
client := cluster.Cores[0].Client
client.SetToken(cluster.RootToken)
// Sanity check
secret, err := client.Auth().Token().LookupSelf()
// Convert the unseal keys to base64 encoded, since these are how the user
// will get them.
unsealKeys := make([]string, len(cluster.BarrierKeys))
for i := range unsealKeys {
unsealKeys[i] = base64.StdEncoding.EncodeToString(cluster.BarrierKeys[i])
}
return client, unsealKeys, func() { defer cluster.Cleanup() }
}
// testVaultServerBad creates an http server that returns a 500 on each request
// to simulate failures.
func testVaultServerBad(t testing.TB) (*api.Client, func()) {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
if secret == nil || secret.Data["id"].(string) != cluster.RootToken {
t.Fatalf("token mismatch: %#v vs %q", secret, cluster.RootToken)
server := &http.Server{
Addr: "127.0.0.1:0",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "500 internal server error", http.StatusInternalServerError)
}),
ReadTimeout: 1 * time.Second,
ReadHeaderTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
IdleTimeout: 1 * time.Second,
}
go func() {
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
t.Fatal(err)
}
}()
client, err := api.NewClient(&api.Config{
Address: "http://" + listener.Addr().String(),
})
if err != nil {
t.Fatal(err)
}
return client, func() {
ctx, done := context.WithTimeout(context.Background(), 5*time.Second)
defer done()
server.Shutdown(ctx)
}
return client, func() { defer cluster.Cleanup() }
}
// testPostgresDB creates a testing postgres database in a Docker container,

View File

@ -12,6 +12,7 @@ import (
"strings"
"sync"
"time"
"unicode"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-cleanhttp"
@ -530,8 +531,17 @@ func (c *Client) RawRequest(r *Request) (*Response, error) {
c.modifyLock.RLock()
c.config.modifyLock.RLock()
defer c.config.modifyLock.RUnlock()
token := c.token
c.modifyLock.RUnlock()
// Sanity check the token before potentially erroring from the API
idx := strings.IndexFunc(token, func(c rune) bool {
return !unicode.IsPrint(c)
})
if idx != -1 {
return nil, fmt.Errorf("Configured Vault token contains non-printable characters and cannot be used.")
}
redirectCount := 0
START:
req, err := r.ToHTTP()

View File

@ -5,6 +5,7 @@ import (
"io"
"net/http"
"os"
"strings"
"testing"
"time"
)
@ -95,6 +96,30 @@ func TestClientToken(t *testing.T) {
}
}
func TestClientBadToken(t *testing.T) {
handler := func(w http.ResponseWriter, req *http.Request) {}
config, ln := testHTTPServer(t, http.HandlerFunc(handler))
defer ln.Close()
client, err := NewClient(config)
if err != nil {
t.Fatalf("err: %s", err)
}
client.SetToken("foo")
_, err = client.RawRequest(client.NewRequest("PUT", "/"))
if err != nil {
t.Fatal(err)
}
client.SetToken("foo\u007f")
_, err = client.RawRequest(client.NewRequest("PUT", "/"))
if err == nil || !strings.Contains(err.Error(), "printable") {
t.Fatalf("expected error due to bad token")
}
}
func TestClientRedirect(t *testing.T) {
primary := func(w http.ResponseWriter, req *http.Request) {
w.Write([]byte("test"))

View File

@ -5,20 +5,12 @@ import (
"time"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/builtin/logical/database"
"github.com/hashicorp/vault/builtin/logical/pki"
"github.com/hashicorp/vault/builtin/logical/transit"
"github.com/hashicorp/vault/logical"
)
func TestRenewer_Renew(t *testing.T) {
t.Parallel()
client, vaultDone := testVaultServerBackends(t, map[string]logical.Factory{
"database": database.Factory,
"pki": pki.Factory,
"transit": transit.Factory,
})
client, vaultDone := testVaultServer(t)
defer vaultDone()
pgURL, pgDone := testPostgresDB(t)

View File

@ -1,10 +1,12 @@
package api
import (
"fmt"
"io"
"time"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/parseutil"
)
// Secret is the structure returned for every secret within Vault.
@ -35,6 +37,188 @@ type Secret struct {
WrapInfo *SecretWrapInfo `json:"wrap_info,omitempty"`
}
// TokenID returns the standardized token ID (token) for the given secret.
func (s *Secret) TokenID() (string, error) {
if s == nil {
return "", nil
}
if s.Auth != nil && len(s.Auth.ClientToken) > 0 {
return s.Auth.ClientToken, nil
}
if s.Data == nil || s.Data["id"] == nil {
return "", nil
}
id, ok := s.Data["id"].(string)
if !ok {
return "", fmt.Errorf("token found but in the wrong format")
}
return id, nil
}
// TokenAccessor returns the standardized token accessor for the given secret.
// If the secret is nil or does not contain an accessor, this returns the empty
// string.
func (s *Secret) TokenAccessor() (string, error) {
if s == nil {
return "", nil
}
if s.Auth != nil && len(s.Auth.Accessor) > 0 {
return s.Auth.Accessor, nil
}
if s.Data == nil || s.Data["accessor"] == nil {
return "", nil
}
accessor, ok := s.Data["accessor"].(string)
if !ok {
return "", fmt.Errorf("token found but in the wrong format")
}
return accessor, nil
}
// TokenRemainingUses returns the standardized remaining uses for the given
// secret. If the secret is nil or does not contain the "num_uses", this
// returns -1. On error, this will return -1 and a non-nil error.
func (s *Secret) TokenRemainingUses() (int, error) {
if s == nil || s.Data == nil || s.Data["num_uses"] == nil {
return -1, nil
}
uses, err := parseutil.ParseInt(s.Data["num_uses"])
if err != nil {
return 0, err
}
return int(uses), nil
}
// TokenPolicies returns the standardized list of policies for the given secret.
// If the secret is nil or does not contain any policies, this returns nil.
func (s *Secret) TokenPolicies() ([]string, error) {
if s == nil {
return nil, nil
}
if s.Auth != nil && len(s.Auth.Policies) > 0 {
return s.Auth.Policies, nil
}
if s.Data == nil || s.Data["policies"] == nil {
return nil, nil
}
sList, ok := s.Data["policies"].([]string)
if ok {
return sList, nil
}
list, ok := s.Data["policies"].([]interface{})
if !ok {
return nil, fmt.Errorf("unable to convert token policies to expected format")
}
policies := make([]string, len(list))
for i := range list {
p, ok := list[i].(string)
if !ok {
return nil, fmt.Errorf("unable to convert policy %v to string", list[i])
}
policies[i] = p
}
return policies, nil
}
// TokenMetadata returns the map of metadata associated with this token, if any
// exists. If the secret is nil or does not contain the "metadata" key, this
// returns nil.
func (s *Secret) TokenMetadata() (map[string]string, error) {
if s == nil {
return nil, nil
}
if s.Auth != nil && len(s.Auth.Metadata) > 0 {
return s.Auth.Metadata, nil
}
if s.Data == nil || (s.Data["metadata"] == nil && s.Data["meta"] == nil) {
return nil, nil
}
data, ok := s.Data["metadata"].(map[string]interface{})
if !ok {
data, ok = s.Data["meta"].(map[string]interface{})
if !ok {
return nil, fmt.Errorf("unable to convert metadata field to expected format")
}
}
metadata := make(map[string]string, len(data))
for k, v := range data {
typed, ok := v.(string)
if !ok {
return nil, fmt.Errorf("unable to convert metadata value %v to string", v)
}
metadata[k] = typed
}
return metadata, nil
}
// TokenIsRenewable returns the standardized token renewability for the given
// secret. If the secret is nil or does not contain the "renewable" key, this
// returns false.
func (s *Secret) TokenIsRenewable() (bool, error) {
if s == nil {
return false, nil
}
if s.Auth != nil && s.Auth.Renewable {
return s.Auth.Renewable, nil
}
if s.Data == nil || s.Data["renewable"] == nil {
return false, nil
}
renewable, err := parseutil.ParseBool(s.Data["renewable"])
if err != nil {
return false, fmt.Errorf("could not convert renewable value to a boolean: %v", err)
}
return renewable, nil
}
// TokenTTL returns the standardized remaining token TTL for the given secret.
// If the secret is nil or does not contain a TTL, this returns 0.
func (s *Secret) TokenTTL() (time.Duration, error) {
if s == nil {
return 0, nil
}
if s.Auth != nil && s.Auth.LeaseDuration > 0 {
return time.Duration(s.Auth.LeaseDuration) * time.Second, nil
}
if s.Data == nil || s.Data["ttl"] == nil {
return 0, nil
}
ttl, err := parseutil.ParseDurationSecond(s.Data["ttl"])
if err != nil {
return 0, err
}
return ttl, nil
}
// SecretWrapInfo contains wrapping information if we have it. If what is
// contained is an authentication token, the accessor for the token will be
// available in WrappedAccessor.

File diff suppressed because it is too large Load Diff

View File

@ -5,8 +5,10 @@ func (c *Sys) Health() (*HealthResponse, error) {
// If the code is 400 or above it will automatically turn into an error,
// but the sys/health API defaults to returning 5xx when not sealed or
// inited, so we force this code to be something else so we parse correctly
r.Params.Add("sealedcode", "299")
r.Params.Add("uninitcode", "299")
r.Params.Add("sealedcode", "299")
r.Params.Add("standbycode", "299")
r.Params.Add("drsecondarycode", "299")
resp, err := c.c.RawRequest(r)
if err != nil {
return nil, err
@ -19,11 +21,13 @@ func (c *Sys) Health() (*HealthResponse, error) {
}
type HealthResponse struct {
Initialized bool `json:"initialized"`
Sealed bool `json:"sealed"`
Standby bool `json:"standby"`
ServerTimeUTC int64 `json:"server_time_utc"`
Version string `json:"version"`
ClusterName string `json:"cluster_name,omitempty"`
ClusterID string `json:"cluster_id,omitempty"`
Initialized bool `json:"initialized"`
Sealed bool `json:"sealed"`
Standby bool `json:"standby"`
ReplicationPerformanceMode string `json:"replication_performance_mode"`
ReplicationDRMode string `json:"replication_dr_mode"`
ServerTimeUTC int64 `json:"server_time_utc"`
Version string `json:"version"`
ClusterName string `json:"cluster_name,omitempty"`
ClusterID string `json:"cluster_id,omitempty"`
}

View File

@ -1,6 +1,8 @@
package audit
import (
"context"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical"
)
@ -14,13 +16,13 @@ type Backend interface {
// request is authorized but before the request is executed. The arguments
// MUST not be modified in anyway. They should be deep copied if this is
// a possibility.
LogRequest(*logical.Auth, *logical.Request, error) error
LogRequest(context.Context, *logical.Auth, *logical.Request, error) error
// LogResponse is used to synchronously log a response. This is done after
// the request is processed but before the response is sent. The arguments
// MUST not be modified in anyway. They should be deep copied if this is
// a possibility.
LogResponse(*logical.Auth, *logical.Request, *logical.Response, error) error
LogResponse(context.Context, *logical.Auth, *logical.Request, *logical.Response, error) error
// GetHash is used to return the given data with the backend's hash,
// so that a caller can determine if a value in the audit log matches
@ -28,10 +30,10 @@ type Backend interface {
GetHash(string) (string, error)
// Reload is called on SIGHUP for supporting backends.
Reload() error
Reload(context.Context) error
// Invalidate is called for path invalidation
Invalidate()
Invalidate(context.Context)
}
type BackendConfig struct {
@ -46,4 +48,4 @@ type BackendConfig struct {
}
// Factory is the factory function to create an audit backend.
type Factory func(*BackendConfig) (Backend, error)
type Factory func(context.Context, *BackendConfig) (Backend, error)

View File

@ -1,6 +1,7 @@
package audit
import (
"context"
"crypto/sha256"
"fmt"
"reflect"
@ -94,7 +95,7 @@ func TestCopy_response(t *testing.T) {
func TestHashString(t *testing.T) {
inmemStorage := &logical.InmemStorage{}
inmemStorage.Put(&logical.StorageEntry{
inmemStorage.Put(context.Background(), &logical.StorageEntry{
Key: "salt",
Value: []byte("foo"),
})
@ -192,7 +193,7 @@ func TestHash(t *testing.T) {
}
inmemStorage := &logical.InmemStorage{}
inmemStorage.Put(&logical.StorageEntry{
inmemStorage.Put(context.Background(), &logical.StorageEntry{
Key: "salt",
Value: []byte("foo"),
})

View File

@ -1,6 +1,7 @@
package file
import (
"context"
"fmt"
"io/ioutil"
"os"
@ -14,7 +15,7 @@ import (
"github.com/hashicorp/vault/logical"
)
func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
func Factory(ctx context.Context, conf *audit.BackendConfig) (audit.Backend, error) {
if conf.SaltConfig == nil {
return nil, fmt.Errorf("nil salt config")
}
@ -168,7 +169,12 @@ func (b *Backend) GetHash(data string) (string, error) {
return audit.HashString(salt, data), nil
}
func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error {
func (b *Backend) LogRequest(
_ context.Context,
auth *logical.Auth,
req *logical.Request,
outerErr error) error {
b.fileLock.Lock()
defer b.fileLock.Unlock()
@ -199,6 +205,7 @@ func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr
}
func (b *Backend) LogResponse(
_ context.Context,
auth *logical.Auth,
req *logical.Request,
resp *logical.Response,
@ -264,7 +271,7 @@ func (b *Backend) open() error {
return nil
}
func (b *Backend) Reload() error {
func (b *Backend) Reload(_ context.Context) error {
switch b.path {
case "stdout", "discard":
return nil
@ -288,7 +295,7 @@ func (b *Backend) Reload() error {
return b.open()
}
func (b *Backend) Invalidate() {
func (b *Backend) Invalidate(_ context.Context) {
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
b.salt = nil

View File

@ -1,6 +1,7 @@
package file
import (
"context"
"io/ioutil"
"os"
"path/filepath"
@ -33,7 +34,7 @@ func TestAuditFile_fileModeNew(t *testing.T) {
"mode": modeStr,
}
_, err = Factory(&audit.BackendConfig{
_, err = Factory(context.Background(), &audit.BackendConfig{
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},
Config: config,
@ -72,7 +73,7 @@ func TestAuditFile_fileModeExisting(t *testing.T) {
"path": f.Name(),
}
_, err = Factory(&audit.BackendConfig{
_, err = Factory(context.Background(), &audit.BackendConfig{
Config: config,
SaltConfig: &salt.Config{},
SaltView: &logical.InmemStorage{},

View File

@ -2,6 +2,7 @@ package socket
import (
"bytes"
"context"
"fmt"
"net"
"strconv"
@ -15,7 +16,7 @@ import (
"github.com/hashicorp/vault/logical"
)
func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
func Factory(ctx context.Context, conf *audit.BackendConfig) (audit.Backend, error) {
if conf.SaltConfig == nil {
return nil, fmt.Errorf("nil salt config")
}
@ -128,7 +129,7 @@ func (b *Backend) GetHash(data string) (string, error) {
return audit.HashString(salt, data), nil
}
func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error {
func (b *Backend) LogRequest(ctx context.Context, auth *logical.Auth, req *logical.Request, outerErr error) error {
var buf bytes.Buffer
if err := b.formatter.FormatRequest(&buf, b.formatConfig, auth, req, outerErr); err != nil {
return err
@ -137,21 +138,21 @@ func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr
b.Lock()
defer b.Unlock()
err := b.write(buf.Bytes())
err := b.write(ctx, buf.Bytes())
if err != nil {
rErr := b.reconnect()
rErr := b.reconnect(ctx)
if rErr != nil {
err = multierror.Append(err, rErr)
} else {
// Try once more after reconnecting
err = b.write(buf.Bytes())
err = b.write(ctx, buf.Bytes())
}
}
return err
}
func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request,
func (b *Backend) LogResponse(ctx context.Context, auth *logical.Auth, req *logical.Request,
resp *logical.Response, outerErr error) error {
var buf bytes.Buffer
if err := b.formatter.FormatResponse(&buf, b.formatConfig, auth, req, resp, outerErr); err != nil {
@ -161,23 +162,23 @@ func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request,
b.Lock()
defer b.Unlock()
err := b.write(buf.Bytes())
err := b.write(ctx, buf.Bytes())
if err != nil {
rErr := b.reconnect()
rErr := b.reconnect(ctx)
if rErr != nil {
err = multierror.Append(err, rErr)
} else {
// Try once more after reconnecting
err = b.write(buf.Bytes())
err = b.write(ctx, buf.Bytes())
}
}
return err
}
func (b *Backend) write(buf []byte) error {
func (b *Backend) write(ctx context.Context, buf []byte) error {
if b.connection == nil {
if err := b.reconnect(); err != nil {
if err := b.reconnect(ctx); err != nil {
return err
}
}
@ -195,13 +196,14 @@ func (b *Backend) write(buf []byte) error {
return err
}
func (b *Backend) reconnect() error {
func (b *Backend) reconnect(ctx context.Context) error {
if b.connection != nil {
b.connection.Close()
b.connection = nil
}
conn, err := net.Dial(b.socketType, b.address)
dialer := net.Dialer{}
conn, err := dialer.DialContext(ctx, b.socketType, b.address)
if err != nil {
return err
}
@ -211,11 +213,11 @@ func (b *Backend) reconnect() error {
return nil
}
func (b *Backend) Reload() error {
func (b *Backend) Reload(ctx context.Context) error {
b.Lock()
defer b.Unlock()
err := b.reconnect()
err := b.reconnect(ctx)
return err
}
@ -240,7 +242,7 @@ func (b *Backend) Salt() (*salt.Salt, error) {
return salt, nil
}
func (b *Backend) Invalidate() {
func (b *Backend) Invalidate(_ context.Context) {
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
b.salt = nil

View File

@ -2,6 +2,7 @@ package syslog
import (
"bytes"
"context"
"fmt"
"strconv"
"sync"
@ -12,7 +13,7 @@ import (
"github.com/hashicorp/vault/logical"
)
func Factory(conf *audit.BackendConfig) (audit.Backend, error) {
func Factory(ctx context.Context, conf *audit.BackendConfig) (audit.Backend, error) {
if conf.SaltConfig == nil {
return nil, fmt.Errorf("nil salt config")
}
@ -115,7 +116,7 @@ func (b *Backend) GetHash(data string) (string, error) {
return audit.HashString(salt, data), nil
}
func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr error) error {
func (b *Backend) LogRequest(_ context.Context, auth *logical.Auth, req *logical.Request, outerErr error) error {
var buf bytes.Buffer
if err := b.formatter.FormatRequest(&buf, b.formatConfig, auth, req, outerErr); err != nil {
return err
@ -126,7 +127,7 @@ func (b *Backend) LogRequest(auth *logical.Auth, req *logical.Request, outerErr
return err
}
func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request, resp *logical.Response, err error) error {
func (b *Backend) LogResponse(_ context.Context, auth *logical.Auth, req *logical.Request, resp *logical.Response, err error) error {
var buf bytes.Buffer
if err := b.formatter.FormatResponse(&buf, b.formatConfig, auth, req, resp, err); err != nil {
return err
@ -137,7 +138,7 @@ func (b *Backend) LogResponse(auth *logical.Auth, req *logical.Request, resp *lo
return err
}
func (b *Backend) Reload() error {
func (b *Backend) Reload(_ context.Context) error {
return nil
}
@ -161,7 +162,7 @@ func (b *Backend) Salt() (*salt.Salt, error) {
return salt, nil
}
func (b *Backend) Invalidate() {
func (b *Backend) Invalidate(_ context.Context) {
b.saltMutex.Lock()
defer b.saltMutex.Unlock()
b.salt = nil

View File

@ -1,6 +1,7 @@
package appId
import (
"context"
"sync"
"github.com/hashicorp/vault/helper/salt"
@ -8,12 +9,12 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b, err := Backend(conf)
if err != nil {
return nil, err
}
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@ -115,7 +116,7 @@ func (b *backend) Salt() (*salt.Salt, error) {
return salt, nil
}
func (b *backend) invalidate(key string) {
func (b *backend) invalidate(_ context.Context, key string) {
switch key {
case salt.DefaultLocation:
b.SaltMutex.Lock()

View File

@ -5,6 +5,7 @@ import (
"fmt"
"testing"
"github.com/hashicorp/vault/helper/salt"
"github.com/hashicorp/vault/logical"
logicaltest "github.com/hashicorp/vault/logical/testing"
)
@ -13,13 +14,13 @@ func TestBackend_basic(t *testing.T) {
var b *backend
var err error
var storage logical.Storage
factory := func(conf *logical.BackendConfig) (logical.Backend, error) {
factory := func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b, err = Backend(conf)
if err != nil {
t.Fatal(err)
}
storage = conf.StorageView
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@ -53,11 +54,11 @@ func TestBackend_basic(t *testing.T) {
if len(keys) != 1 {
t.Fatalf("expected 1 key, got %d", len(keys))
}
salt, err := b.Salt()
bSalt, err := b.Salt()
if err != nil {
t.Fatal(err)
}
if keys[0] != salt.SaltID("foo") {
if keys[0] != "s"+bSalt.SaltIDHashFunc("foo", salt.SHA256Hash) {
t.Fatal("value was improperly salted")
}
}

View File

@ -84,7 +84,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
userId := data.Get("user_id").(string)
var displayName string
if dispName, resp, err := b.verifyCredentials(req, appId, userId); err != nil {
if dispName, resp, err := b.verifyCredentials(ctx, req, appId, userId); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
@ -93,7 +93,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
}
// Get the policies associated with the app
policies, err := b.MapAppId.Policies(req.Storage, appId)
policies, err := b.MapAppId.Policies(ctx, req.Storage, appId)
if err != nil {
return nil, err
}
@ -131,14 +131,14 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
// Skipping CIDR verification to enable renewal from machines other than
// the ones encompassed by CIDR block.
if _, resp, err := b.verifyCredentials(req, appId, userId); err != nil {
if _, resp, err := b.verifyCredentials(ctx, req, appId, userId); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
}
// Get the policies associated with the app
mapPolicies, err := b.MapAppId.Policies(req.Storage, appId)
mapPolicies, err := b.MapAppId.Policies(ctx, req.Storage, appId)
if err != nil {
return nil, err
}
@ -149,14 +149,14 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
return framework.LeaseExtend(0, 0, b.System())(ctx, req, d)
}
func (b *backend) verifyCredentials(req *logical.Request, appId, userId string) (string, *logical.Response, error) {
func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, appId, userId string) (string, *logical.Response, error) {
// Ensure both appId and userId are provided
if appId == "" || userId == "" {
return "", logical.ErrorResponse("missing 'app_id' or 'user_id'"), nil
}
// Look up the apps that this user is allowed to access
appsMap, err := b.MapUserId.Get(req.Storage, userId)
appsMap, err := b.MapUserId.Get(ctx, req.Storage, userId)
if err != nil {
return "", nil, err
}
@ -205,7 +205,7 @@ func (b *backend) verifyCredentials(req *logical.Request, appId, userId string)
}
// Get the raw data associated with the app
appRaw, err := b.MapAppId.Get(req.Storage, appId)
appRaw, err := b.MapAppId.Get(ctx, req.Storage, appId)
if err != nil {
return "", nil, err
}

View File

@ -1,6 +1,7 @@
package approle
import (
"context"
"sync"
"github.com/hashicorp/vault/helper/locksutil"
@ -49,12 +50,12 @@ type backend struct {
secretIDListingLock sync.RWMutex
}
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b, err := Backend(conf)
if err != nil {
return nil, err
}
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@ -125,7 +126,7 @@ func (b *backend) Salt() (*salt.Salt, error) {
return salt, nil
}
func (b *backend) invalidate(key string) {
func (b *backend) invalidate(_ context.Context, key string) {
switch key {
case salt.DefaultLocation:
b.saltMutex.Lock()
@ -139,9 +140,9 @@ func (b *backend) invalidate(key string) {
// This could mean that the SecretID may live in the backend upto 1 min after its
// expiration. The deletion of SecretIDs are not security sensitive and it is okay
// to delay the removal of SecretIDs by a minute.
func (b *backend) periodicFunc(req *logical.Request) error {
func (b *backend) periodicFunc(ctx context.Context, req *logical.Request) error {
// Initiate clean-up of expired SecretID entries
b.tidySecretID(req.Storage)
b.tidySecretID(ctx, req.Storage)
return nil
}

View File

@ -1,6 +1,7 @@
package approle
import (
"context"
"testing"
"github.com/hashicorp/vault/logical"
@ -17,7 +18,7 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) {
if b == nil {
t.Fatalf("failed to create backend")
}
err = b.Backend.Setup(config)
err = b.Backend.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"strings"
"time"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
@ -50,7 +51,7 @@ func (b *backend) pathLoginUpdateAliasLookahead(ctx context.Context, req *logica
// Returns the Auth object indicating the authentication and authorization information
// if the credentials provided are validated by the backend.
func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
role, roleName, metadata, _, err := b.validateCredentials(req, data)
role, roleName, metadata, _, err := b.validateCredentials(ctx, req, data)
if err != nil || role == nil {
return logical.ErrorResponse(fmt.Sprintf("failed to validate credentials: %v", err)), nil
}
@ -92,7 +93,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, data
defer lock.RUnlock()
// Ensure that the Role still exists.
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, fmt.Errorf("failed to validate role %s during renewal:%s", roleName, err)
}
@ -100,12 +101,17 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, data
return nil, fmt.Errorf("role %s does not exist during renewal", roleName)
}
resp, err := framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(ctx, req, data)
if err != nil {
return nil, err
// If a period is provided, set that as part of resp.Auth.Period and return a
// response immediately. Let expiration manager handle renewal from there on.
if role.Period > time.Duration(0) {
resp := &logical.Response{
Auth: req.Auth,
}
resp.Auth.Period = role.Period
return resp, nil
}
resp.Auth.Period = role.Period
return resp, nil
return framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(ctx, req, data)
}
const pathLoginHelpSys = "Issue a token based on the credentials supplied"

View File

@ -3,6 +3,7 @@ package approle
import (
"context"
"testing"
"time"
"github.com/hashicorp/vault/logical"
)
@ -48,12 +49,106 @@ func TestAppRole_RoleLogin(t *testing.T) {
RemoteAddr: "127.0.0.1",
},
}
resp, err = b.HandleRequest(context.Background(), loginReq)
loginResp, err := b.HandleRequest(context.Background(), loginReq)
if err != nil || (loginResp != nil && loginResp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, loginResp)
}
if loginResp.Auth == nil {
t.Fatalf("expected a non-nil auth object in the response")
}
// Test renewal
renewReq := generateRenewRequest(storage, loginResp.Auth)
resp, err = b.HandleRequest(context.Background(), renewReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp.Auth == nil {
if resp.Auth.TTL != 400*time.Second {
t.Fatalf("expected period value from response to be 400s, got: %s", resp.Auth.TTL)
}
///
// Test renewal with period
///
// Create role
period := 600 * time.Second
roleData := map[string]interface{}{
"policies": "a,b,c",
"period": period.String(),
}
roleReq := &logical.Request{
Operation: logical.CreateOperation,
Path: "role/" + "role-period",
Storage: storage,
Data: roleData,
}
resp, err = b.HandleRequest(context.Background(), roleReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
roleRoleIDReq = &logical.Request{
Operation: logical.ReadOperation,
Path: "role/role-period/role-id",
Storage: storage,
}
resp, err = b.HandleRequest(context.Background(), roleRoleIDReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
roleID = resp.Data["role_id"]
roleSecretIDReq = &logical.Request{
Operation: logical.UpdateOperation,
Path: "role/role-period/secret-id",
Storage: storage,
}
resp, err = b.HandleRequest(context.Background(), roleSecretIDReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
secretID = resp.Data["secret_id"]
loginData["role_id"] = roleID
loginData["secret_id"] = secretID
loginResp, err = b.HandleRequest(context.Background(), loginReq)
if err != nil || (loginResp != nil && loginResp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, loginResp)
}
if loginResp.Auth == nil {
t.Fatalf("expected a non-nil auth object in the response")
}
renewReq = generateRenewRequest(storage, loginResp.Auth)
resp, err = b.HandleRequest(context.Background(), renewReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}
if resp.Auth.Period != period {
t.Fatalf("expected period value of %d in the response, got: %s", period, resp.Auth.Period)
}
}
func generateRenewRequest(s logical.Storage, auth *logical.Auth) *logical.Request {
renewReq := &logical.Request{
Operation: logical.RenewOperation,
Storage: s,
Auth: &logical.Auth{},
}
renewReq.Auth.InternalData = auth.InternalData
renewReq.Auth.Metadata = auth.Metadata
renewReq.Auth.LeaseOptions = auth.LeaseOptions
renewReq.Auth.Policies = auth.Policies
renewReq.Auth.IssueTime = time.Now()
renewReq.Auth.Period = auth.Period
return renewReq
}

View File

@ -523,7 +523,7 @@ func (b *backend) pathRoleExistenceCheck(ctx context.Context, req *logical.Reque
lock.RLock()
defer lock.RUnlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return false, err
}
@ -538,7 +538,7 @@ func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, data *
lock.RLock()
defer lock.RUnlock()
roles, err := req.Storage.List("role/")
roles, err := req.Storage.List(ctx, "role/")
if err != nil {
return nil, err
}
@ -557,7 +557,7 @@ func (b *backend) pathRoleSecretIDList(ctx context.Context, req *logical.Request
defer lock.RUnlock()
// Get the role entry
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -580,7 +580,7 @@ func (b *backend) pathRoleSecretIDList(ctx context.Context, req *logical.Request
// Listing works one level at a time. Get the first level of data
// which could then be used to get the actual SecretID storage entries.
secretIDHMACs, err := req.Storage.List(fmt.Sprintf("secret_id/%s/", roleNameHMAC))
secretIDHMACs, err := req.Storage.List(ctx, fmt.Sprintf("secret_id/%s/", roleNameHMAC))
if err != nil {
return nil, err
}
@ -606,7 +606,7 @@ func (b *backend) pathRoleSecretIDList(ctx context.Context, req *logical.Request
secretIDLock.RLock()
result := secretIDStorageEntry{}
if entry, err := req.Storage.Get(entryIndex); err != nil {
if entry, err := req.Storage.Get(ctx, entryIndex); err != nil {
secretIDLock.RUnlock()
return nil, err
} else if entry == nil {
@ -643,7 +643,7 @@ func validateRoleConstraints(role *roleStorageEntry) error {
// setRoleEntry persists the role and creates an index from roleID to role
// name.
func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleStorageEntry, previousRoleID string) error {
func (b *backend) setRoleEntry(ctx context.Context, s logical.Storage, roleName string, role *roleStorageEntry, previousRoleID string) error {
if roleName == "" {
return fmt.Errorf("missing role name")
}
@ -667,7 +667,7 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto
}
// Check if the index from the role_id to role already exists
roleIDIndex, err := b.roleIDEntry(s, role.RoleID)
roleIDIndex, err := b.roleIDEntry(ctx, s, role.RoleID)
if err != nil {
return fmt.Errorf("failed to read role_id index: %v", err)
}
@ -680,13 +680,13 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto
// When role_id is getting updated, delete the old index before
// a new one is created
if previousRoleID != "" && previousRoleID != role.RoleID {
if err = b.roleIDEntryDelete(s, previousRoleID); err != nil {
if err = b.roleIDEntryDelete(ctx, s, previousRoleID); err != nil {
return fmt.Errorf("failed to delete previous role ID index")
}
}
// Save the role entry only after all the validations
if err = s.Put(entry); err != nil {
if err = s.Put(ctx, entry); err != nil {
return err
}
@ -697,20 +697,20 @@ func (b *backend) setRoleEntry(s logical.Storage, roleName string, role *roleSto
// Create a storage entry for reverse mapping of RoleID to role.
// Note that secondary index is created when the roleLock is held.
return b.setRoleIDEntry(s, role.RoleID, &roleIDStorageEntry{
return b.setRoleIDEntry(ctx, s, role.RoleID, &roleIDStorageEntry{
Name: roleName,
})
}
// roleEntry reads the role from storage
func (b *backend) roleEntry(s logical.Storage, roleName string) (*roleStorageEntry, error) {
func (b *backend) roleEntry(ctx context.Context, s logical.Storage, roleName string) (*roleStorageEntry, error) {
if roleName == "" {
return nil, fmt.Errorf("missing role_name")
}
var role roleStorageEntry
if entry, err := s.Get("role/" + strings.ToLower(roleName)); err != nil {
if entry, err := s.Get(ctx, "role/"+strings.ToLower(roleName)); err != nil {
return nil, err
} else if entry == nil {
return nil, nil
@ -734,7 +734,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
defer lock.Unlock()
// Check if the role already exists
role, err := b.roleEntry(req.Storage, roleName)
role, err := b.roleEntry(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
@ -855,7 +855,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
}
// Store the entry.
return resp, b.setRoleEntry(req.Storage, roleName, role, previousRoleID)
return resp, b.setRoleEntry(ctx, req.Storage, roleName, role, previousRoleID)
}
// pathRoleRead grabs a read lock and reads the options set on the role from the storage
@ -869,7 +869,7 @@ func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *
lock.RLock()
lockRelease := lock.RUnlock
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
lockRelease()
return nil, err
@ -902,7 +902,7 @@ func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *
// For sanity, verify that the index still exists. If the index is missing,
// add one and return a warning so it can be reported.
roleIDIndex, err := b.roleIDEntry(req.Storage, role.RoleID)
roleIDIndex, err := b.roleIDEntry(ctx, req.Storage, role.RoleID)
if err != nil {
lockRelease()
return nil, err
@ -915,7 +915,7 @@ func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *
lockRelease = lock.Unlock
// Check again if the index is missing
roleIDIndex, err = b.roleIDEntry(req.Storage, role.RoleID)
roleIDIndex, err = b.roleIDEntry(ctx, req.Storage, role.RoleID)
if err != nil {
lockRelease()
return nil, err
@ -923,7 +923,7 @@ func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *
if roleIDIndex == nil {
// Create a new index
err = b.setRoleIDEntry(req.Storage, role.RoleID, &roleIDStorageEntry{
err = b.setRoleIDEntry(ctx, req.Storage, role.RoleID, &roleIDStorageEntry{
Name: roleName,
})
if err != nil {
@ -950,7 +950,7 @@ func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -959,17 +959,17 @@ func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data
}
// Just before the role is deleted, remove all the SecretIDs issued as part of the role.
if err = b.flushRoleSecrets(req.Storage, roleName, role.HMACKey); err != nil {
if err = b.flushRoleSecrets(ctx, req.Storage, roleName, role.HMACKey); err != nil {
return nil, fmt.Errorf("failed to invalidate the secrets belonging to role %q: %v", roleName, err)
}
// Delete the reverse mapping from RoleID to the role
if err = b.roleIDEntryDelete(req.Storage, role.RoleID); err != nil {
if err = b.roleIDEntryDelete(ctx, req.Storage, role.RoleID); err != nil {
return nil, fmt.Errorf("failed to delete the mapping from RoleID to role %q: %v", roleName, err)
}
// After deleting the SecretIDs and the RoleID, delete the role itself
if err = req.Storage.Delete("role/" + strings.ToLower(roleName)); err != nil {
if err = req.Storage.Delete(ctx, "role/"+strings.ToLower(roleName)); err != nil {
return nil, err
}
@ -993,7 +993,7 @@ func (b *backend) pathRoleSecretIDLookupUpdate(ctx context.Context, req *logical
defer lock.RUnlock()
// Fetch the role
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1020,16 +1020,16 @@ func (b *backend) pathRoleSecretIDLookupUpdate(ctx context.Context, req *logical
// Create the index at which the secret_id would've been stored
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, secretIDHMAC)
return b.secretIDCommon(req.Storage, entryIndex, secretIDHMAC)
return b.secretIDCommon(ctx, req.Storage, entryIndex, secretIDHMAC)
}
func (b *backend) secretIDCommon(s logical.Storage, entryIndex, secretIDHMAC string) (*logical.Response, error) {
func (b *backend) secretIDCommon(ctx context.Context, s logical.Storage, entryIndex, secretIDHMAC string) (*logical.Response, error) {
lock := b.secretIDLock(secretIDHMAC)
lock.RLock()
defer lock.RUnlock()
result := secretIDStorageEntry{}
if entry, err := s.Get(entryIndex); err != nil {
if entry, err := s.Get(ctx, entryIndex); err != nil {
return nil, err
} else if entry == nil {
return nil, nil
@ -1075,7 +1075,7 @@ func (b *backend) pathRoleSecretIDDestroyUpdateDelete(ctx context.Context, req *
roleLock.RLock()
defer roleLock.RUnlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1100,7 +1100,7 @@ func (b *backend) pathRoleSecretIDDestroyUpdateDelete(ctx context.Context, req *
defer lock.Unlock()
result := secretIDStorageEntry{}
if entry, err := req.Storage.Get(entryIndex); err != nil {
if entry, err := req.Storage.Get(ctx, entryIndex); err != nil {
return nil, err
} else if entry == nil {
return nil, nil
@ -1109,12 +1109,12 @@ func (b *backend) pathRoleSecretIDDestroyUpdateDelete(ctx context.Context, req *
}
// Delete the accessor of the SecretID first
if err := b.deleteSecretIDAccessorEntry(req.Storage, result.SecretIDAccessor); err != nil {
if err := b.deleteSecretIDAccessorEntry(ctx, req.Storage, result.SecretIDAccessor); err != nil {
return nil, err
}
// Delete the storage entry that corresponds to the SecretID
if err := req.Storage.Delete(entryIndex); err != nil {
if err := req.Storage.Delete(ctx, entryIndex); err != nil {
return nil, fmt.Errorf("failed to delete secret_id: %v", err)
}
@ -1142,7 +1142,7 @@ func (b *backend) pathRoleSecretIDAccessorLookupUpdate(ctx context.Context, req
lock.RLock()
defer lock.RUnlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1150,7 +1150,7 @@ func (b *backend) pathRoleSecretIDAccessorLookupUpdate(ctx context.Context, req
return nil, fmt.Errorf("role %q does not exist", roleName)
}
accessorEntry, err := b.secretIDAccessorEntry(req.Storage, secretIDAccessor)
accessorEntry, err := b.secretIDAccessorEntry(ctx, req.Storage, secretIDAccessor)
if err != nil {
return nil, err
}
@ -1165,7 +1165,7 @@ func (b *backend) pathRoleSecretIDAccessorLookupUpdate(ctx context.Context, req
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, accessorEntry.SecretIDHMAC)
return b.secretIDCommon(req.Storage, entryIndex, accessorEntry.SecretIDHMAC)
return b.secretIDCommon(ctx, req.Storage, entryIndex, accessorEntry.SecretIDHMAC)
}
func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -1183,7 +1183,7 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(ctx context.Contex
// Get the role details to fetch the RoleID and accessor to get
// the HMACed SecretID.
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1191,7 +1191,7 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(ctx context.Contex
return nil, fmt.Errorf("role %q does not exist", roleName)
}
accessorEntry, err := b.secretIDAccessorEntry(req.Storage, secretIDAccessor)
accessorEntry, err := b.secretIDAccessorEntry(ctx, req.Storage, secretIDAccessor)
if err != nil {
return nil, err
}
@ -1211,12 +1211,12 @@ func (b *backend) pathRoleSecretIDAccessorDestroyUpdateDelete(ctx context.Contex
defer lock.Unlock()
// Delete the accessor of the SecretID first
if err := b.deleteSecretIDAccessorEntry(req.Storage, secretIDAccessor); err != nil {
if err := b.deleteSecretIDAccessorEntry(ctx, req.Storage, secretIDAccessor); err != nil {
return nil, err
}
// Delete the storage entry that corresponds to the SecretID
if err := req.Storage.Delete(entryIndex); err != nil {
if err := req.Storage.Delete(ctx, entryIndex); err != nil {
return nil, fmt.Errorf("failed to delete secret_id: %v", err)
}
@ -1234,7 +1234,7 @@ func (b *backend) pathRoleBoundCIDRListUpdate(ctx context.Context, req *logical.
defer lock.Unlock()
// Re-read the role after grabbing the lock
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1257,7 +1257,7 @@ func (b *backend) pathRoleBoundCIDRListUpdate(ctx context.Context, req *logical.
}
}
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleBoundCIDRListRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -1270,7 +1270,7 @@ func (b *backend) pathRoleBoundCIDRListRead(ctx context.Context, req *logical.Re
lock.Lock()
defer lock.Unlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@ -1293,7 +1293,7 @@ func (b *backend) pathRoleBoundCIDRListDelete(ctx context.Context, req *logical.
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1304,7 +1304,7 @@ func (b *backend) pathRoleBoundCIDRListDelete(ctx context.Context, req *logical.
// Deleting a field implies setting the value to it's default value.
role.BoundCIDRList = data.GetDefaultOrZero("bound_cidr_list").(string)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleBindSecretIDUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -1317,7 +1317,7 @@ func (b *backend) pathRoleBindSecretIDUpdate(ctx context.Context, req *logical.R
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1327,7 +1327,7 @@ func (b *backend) pathRoleBindSecretIDUpdate(ctx context.Context, req *logical.R
if bindSecretIDRaw, ok := data.GetOk("bind_secret_id"); ok {
role.BindSecretID = bindSecretIDRaw.(bool)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
} else {
return logical.ErrorResponse("missing bind_secret_id"), nil
}
@ -1343,7 +1343,7 @@ func (b *backend) pathRoleBindSecretIDRead(ctx context.Context, req *logical.Req
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@ -1366,7 +1366,7 @@ func (b *backend) pathRoleBindSecretIDDelete(ctx context.Context, req *logical.R
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1377,7 +1377,7 @@ func (b *backend) pathRoleBindSecretIDDelete(ctx context.Context, req *logical.R
// Deleting a field implies setting the value to it's default value.
role.BindSecretID = data.GetDefaultOrZero("bind_secret_id").(bool)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRolePoliciesUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -1390,7 +1390,7 @@ func (b *backend) pathRolePoliciesUpdate(ctx context.Context, req *logical.Reque
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1405,7 +1405,7 @@ func (b *backend) pathRolePoliciesUpdate(ctx context.Context, req *logical.Reque
role.Policies = policyutil.ParsePolicies(policiesRaw)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRolePoliciesRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -1418,7 +1418,7 @@ func (b *backend) pathRolePoliciesRead(ctx context.Context, req *logical.Request
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@ -1441,7 +1441,7 @@ func (b *backend) pathRolePoliciesDelete(ctx context.Context, req *logical.Reque
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1451,7 +1451,7 @@ func (b *backend) pathRolePoliciesDelete(ctx context.Context, req *logical.Reque
role.Policies = []string{}
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleSecretIDNumUsesUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -1464,7 +1464,7 @@ func (b *backend) pathRoleSecretIDNumUsesUpdate(ctx context.Context, req *logica
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1477,7 +1477,7 @@ func (b *backend) pathRoleSecretIDNumUsesUpdate(ctx context.Context, req *logica
if role.SecretIDNumUses < 0 {
return logical.ErrorResponse("secret_id_num_uses cannot be negative"), nil
}
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
} else {
return logical.ErrorResponse("missing secret_id_num_uses"), nil
}
@ -1493,7 +1493,7 @@ func (b *backend) pathRoleRoleIDUpdate(ctx context.Context, req *logical.Request
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1507,7 +1507,7 @@ func (b *backend) pathRoleRoleIDUpdate(ctx context.Context, req *logical.Request
return logical.ErrorResponse("missing role_id"), nil
}
return nil, b.setRoleEntry(req.Storage, roleName, role, previousRoleID)
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, previousRoleID)
}
func (b *backend) pathRoleRoleIDRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -1520,7 +1520,7 @@ func (b *backend) pathRoleRoleIDRead(ctx context.Context, req *logical.Request,
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@ -1543,7 +1543,7 @@ func (b *backend) pathRoleSecretIDNumUsesRead(ctx context.Context, req *logical.
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@ -1566,7 +1566,7 @@ func (b *backend) pathRoleSecretIDNumUsesDelete(ctx context.Context, req *logica
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1576,7 +1576,7 @@ func (b *backend) pathRoleSecretIDNumUsesDelete(ctx context.Context, req *logica
role.SecretIDNumUses = data.GetDefaultOrZero("secret_id_num_uses").(int)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleSecretIDTTLUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -1589,7 +1589,7 @@ func (b *backend) pathRoleSecretIDTTLUpdate(ctx context.Context, req *logical.Re
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1599,7 +1599,7 @@ func (b *backend) pathRoleSecretIDTTLUpdate(ctx context.Context, req *logical.Re
if secretIDTTLRaw, ok := data.GetOk("secret_id_ttl"); ok {
role.SecretIDTTL = time.Second * time.Duration(secretIDTTLRaw.(int))
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
} else {
return logical.ErrorResponse("missing secret_id_ttl"), nil
}
@ -1615,7 +1615,7 @@ func (b *backend) pathRoleSecretIDTTLRead(ctx context.Context, req *logical.Requ
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@ -1639,7 +1639,7 @@ func (b *backend) pathRoleSecretIDTTLDelete(ctx context.Context, req *logical.Re
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1649,7 +1649,7 @@ func (b *backend) pathRoleSecretIDTTLDelete(ctx context.Context, req *logical.Re
role.SecretIDTTL = time.Second * time.Duration(data.GetDefaultOrZero("secret_id_ttl").(int))
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRolePeriodUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -1662,7 +1662,7 @@ func (b *backend) pathRolePeriodUpdate(ctx context.Context, req *logical.Request
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1675,7 +1675,7 @@ func (b *backend) pathRolePeriodUpdate(ctx context.Context, req *logical.Request
if role.Period > b.System().MaxLeaseTTL() {
return logical.ErrorResponse(fmt.Sprintf("period of %q is greater than the backend's maximum lease TTL of %q", role.Period.String(), b.System().MaxLeaseTTL().String())), nil
}
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
} else {
return logical.ErrorResponse("missing period"), nil
}
@ -1691,7 +1691,7 @@ func (b *backend) pathRolePeriodRead(ctx context.Context, req *logical.Request,
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@ -1715,7 +1715,7 @@ func (b *backend) pathRolePeriodDelete(ctx context.Context, req *logical.Request
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1725,7 +1725,7 @@ func (b *backend) pathRolePeriodDelete(ctx context.Context, req *logical.Request
role.Period = time.Second * time.Duration(data.GetDefaultOrZero("period").(int))
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleTokenNumUsesUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -1738,7 +1738,7 @@ func (b *backend) pathRoleTokenNumUsesUpdate(ctx context.Context, req *logical.R
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1748,7 +1748,7 @@ func (b *backend) pathRoleTokenNumUsesUpdate(ctx context.Context, req *logical.R
if tokenNumUsesRaw, ok := data.GetOk("token_num_uses"); ok {
role.TokenNumUses = tokenNumUsesRaw.(int)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
} else {
return logical.ErrorResponse("missing token_num_uses"), nil
}
@ -1764,7 +1764,7 @@ func (b *backend) pathRoleTokenNumUsesRead(ctx context.Context, req *logical.Req
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@ -1787,7 +1787,7 @@ func (b *backend) pathRoleTokenNumUsesDelete(ctx context.Context, req *logical.R
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1797,7 +1797,7 @@ func (b *backend) pathRoleTokenNumUsesDelete(ctx context.Context, req *logical.R
role.TokenNumUses = data.GetDefaultOrZero("token_num_uses").(int)
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleTokenTTLUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -1810,7 +1810,7 @@ func (b *backend) pathRoleTokenTTLUpdate(ctx context.Context, req *logical.Reque
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1823,7 +1823,7 @@ func (b *backend) pathRoleTokenTTLUpdate(ctx context.Context, req *logical.Reque
if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL {
return logical.ErrorResponse("token_ttl should not be greater than token_max_ttl"), nil
}
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
} else {
return logical.ErrorResponse("missing token_ttl"), nil
}
@ -1839,7 +1839,7 @@ func (b *backend) pathRoleTokenTTLRead(ctx context.Context, req *logical.Request
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@ -1863,7 +1863,7 @@ func (b *backend) pathRoleTokenTTLDelete(ctx context.Context, req *logical.Reque
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1873,7 +1873,7 @@ func (b *backend) pathRoleTokenTTLDelete(ctx context.Context, req *logical.Reque
role.TokenTTL = time.Second * time.Duration(data.GetDefaultOrZero("token_ttl").(int))
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleTokenMaxTTLUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -1886,7 +1886,7 @@ func (b *backend) pathRoleTokenMaxTTLUpdate(ctx context.Context, req *logical.Re
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1899,7 +1899,7 @@ func (b *backend) pathRoleTokenMaxTTLUpdate(ctx context.Context, req *logical.Re
if role.TokenMaxTTL > time.Duration(0) && role.TokenTTL > role.TokenMaxTTL {
return logical.ErrorResponse("token_max_ttl should be greater than or equal to token_ttl"), nil
}
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
} else {
return logical.ErrorResponse("missing token_max_ttl"), nil
}
@ -1915,7 +1915,7 @@ func (b *backend) pathRoleTokenMaxTTLRead(ctx context.Context, req *logical.Requ
lock.RLock()
defer lock.RUnlock()
if role, err := b.roleEntry(req.Storage, strings.ToLower(roleName)); err != nil {
if role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName)); err != nil {
return nil, err
} else if role == nil {
return nil, nil
@ -1939,7 +1939,7 @@ func (b *backend) pathRoleTokenMaxTTLDelete(ctx context.Context, req *logical.Re
lock.Lock()
defer lock.Unlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -1949,7 +1949,7 @@ func (b *backend) pathRoleTokenMaxTTLDelete(ctx context.Context, req *logical.Re
role.TokenMaxTTL = time.Second * time.Duration(data.GetDefaultOrZero("token_max_ttl").(int))
return nil, b.setRoleEntry(req.Storage, roleName, role, "")
return nil, b.setRoleEntry(ctx, req.Storage, roleName, role, "")
}
func (b *backend) pathRoleSecretIDUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -1978,7 +1978,7 @@ func (b *backend) handleRoleSecretIDCommon(ctx context.Context, req *logical.Req
lock.RLock()
defer lock.RUnlock()
role, err := b.roleEntry(req.Storage, strings.ToLower(roleName))
role, err := b.roleEntry(ctx, req.Storage, strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -2026,7 +2026,7 @@ func (b *backend) handleRoleSecretIDCommon(ctx context.Context, req *logical.Req
roleName = strings.ToLower(roleName)
}
if secretIDStorage, err = b.registerSecretIDEntry(req.Storage, roleName, secretID, role.HMACKey, secretIDStorage); err != nil {
if secretIDStorage, err = b.registerSecretIDEntry(ctx, req.Storage, roleName, secretID, role.HMACKey, secretIDStorage); err != nil {
return nil, fmt.Errorf("failed to store secret_id: %v", err)
}
@ -2047,7 +2047,7 @@ func (b *backend) roleLock(roleName string) *locksutil.LockEntry {
}
// setRoleIDEntry creates a storage entry that maps RoleID to Role
func (b *backend) setRoleIDEntry(s logical.Storage, roleID string, roleIDEntry *roleIDStorageEntry) error {
func (b *backend) setRoleIDEntry(ctx context.Context, s logical.Storage, roleID string, roleIDEntry *roleIDStorageEntry) error {
lock := b.roleIDLock(roleID)
lock.Lock()
defer lock.Unlock()
@ -2062,14 +2062,14 @@ func (b *backend) setRoleIDEntry(s logical.Storage, roleID string, roleIDEntry *
if err != nil {
return err
}
if err = s.Put(entry); err != nil {
if err = s.Put(ctx, entry); err != nil {
return err
}
return nil
}
// roleIDEntry is used to read the storage entry that maps RoleID to Role
func (b *backend) roleIDEntry(s logical.Storage, roleID string) (*roleIDStorageEntry, error) {
func (b *backend) roleIDEntry(ctx context.Context, s logical.Storage, roleID string) (*roleIDStorageEntry, error) {
if roleID == "" {
return nil, fmt.Errorf("missing roleID")
}
@ -2086,7 +2086,7 @@ func (b *backend) roleIDEntry(s logical.Storage, roleID string) (*roleIDStorageE
}
entryIndex := "role_id/" + salt.SaltID(roleID)
if entry, err := s.Get(entryIndex); err != nil {
if entry, err := s.Get(ctx, entryIndex); err != nil {
return nil, err
} else if entry == nil {
return nil, nil
@ -2099,7 +2099,7 @@ func (b *backend) roleIDEntry(s logical.Storage, roleID string) (*roleIDStorageE
// roleIDEntryDelete is used to remove the secondary index that maps the
// RoleID to the Role itself.
func (b *backend) roleIDEntryDelete(s logical.Storage, roleID string) error {
func (b *backend) roleIDEntryDelete(ctx context.Context, s logical.Storage, roleID string) error {
if roleID == "" {
return fmt.Errorf("missing roleID")
}
@ -2114,7 +2114,7 @@ func (b *backend) roleIDEntryDelete(s logical.Storage, roleID string) error {
}
entryIndex := "role_id/" + salt.SaltID(roleID)
return s.Delete(entryIndex)
return s.Delete(ctx, entryIndex)
}
var roleHelp = map[string][2]string{

View File

@ -26,7 +26,7 @@ func TestApprole_RoleNameLowerCasing(t *testing.T) {
Policies: []string{"default"},
BindSecretID: true,
}
err = b.setRoleEntry(storage, "testRoleName", role, "")
err = b.setRoleEntry(context.Background(), storage, "testRoleName", role, "")
if err != nil {
t.Fatal(err)
}
@ -208,7 +208,7 @@ func TestAppRole_RoleReadSetIndex(t *testing.T) {
roleID := resp.Data["role_id"].(string)
// Delete the role ID index
err = b.roleIDEntryDelete(storage, roleID)
err = b.roleIDEntryDelete(context.Background(), storage, roleID)
if err != nil {
t.Fatal(err)
}
@ -225,7 +225,7 @@ func TestAppRole_RoleReadSetIndex(t *testing.T) {
t.Fatalf("bad: expected a warning in the response")
}
roleIDIndex, err := b.roleIDEntry(storage, roleID)
roleIDIndex, err := b.roleIDEntry(context.Background(), storage, roleID)
if err != nil {
t.Fatal(err)
}

View File

@ -25,7 +25,7 @@ func pathTidySecretID(b *backend) *framework.Path {
}
// tidySecretID is used to delete entries in the whitelist that are expired.
func (b *backend) tidySecretID(s logical.Storage) error {
func (b *backend) tidySecretID(ctx context.Context, s logical.Storage) error {
grabbed := atomic.CompareAndSwapUint32(&b.tidySecretIDCASGuard, 0, 1)
if grabbed {
defer atomic.StoreUint32(&b.tidySecretIDCASGuard, 0)
@ -33,7 +33,7 @@ func (b *backend) tidySecretID(s logical.Storage) error {
return fmt.Errorf("SecretID tidy operation already running")
}
roleNameHMACs, err := s.List("secret_id/")
roleNameHMACs, err := s.List(ctx, "secret_id/")
if err != nil {
return err
}
@ -41,7 +41,7 @@ func (b *backend) tidySecretID(s logical.Storage) error {
var result error
for _, roleNameHMAC := range roleNameHMACs {
// roleNameHMAC will already have a '/' suffix. Don't append another one.
secretIDHMACs, err := s.List(fmt.Sprintf("secret_id/%s", roleNameHMAC))
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("secret_id/%s", roleNameHMAC))
if err != nil {
return err
}
@ -52,7 +52,7 @@ func (b *backend) tidySecretID(s logical.Storage) error {
lock.Lock()
// roleNameHMAC will already have a '/' suffix. Don't append another one.
entryIndex := fmt.Sprintf("secret_id/%s%s", roleNameHMAC, secretIDHMAC)
secretIDEntry, err := s.Get(entryIndex)
secretIDEntry, err := s.Get(ctx, entryIndex)
if err != nil {
lock.Unlock()
return fmt.Errorf("error fetching SecretID %s: %s", secretIDHMAC, err)
@ -77,7 +77,7 @@ func (b *backend) tidySecretID(s logical.Storage) error {
// ExpirationTime not being set indicates non-expiring SecretIDs
if !result.ExpirationTime.IsZero() && time.Now().After(result.ExpirationTime) {
if err := s.Delete(entryIndex); err != nil {
if err := s.Delete(ctx, entryIndex); err != nil {
lock.Unlock()
return fmt.Errorf("error deleting SecretID %s from storage: %s", secretIDHMAC, err)
}
@ -90,7 +90,7 @@ func (b *backend) tidySecretID(s logical.Storage) error {
// pathTidySecretIDUpdate is used to delete the expired SecretID entries
func (b *backend) pathTidySecretIDUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
return nil, b.tidySecretID(req.Storage)
return nil, b.tidySecretID(ctx, req.Storage)
}
const pathTidySecretIDSyn = "Trigger the clean-up of expired SecretID entries."

View File

@ -1,6 +1,7 @@
package approle
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
@ -68,9 +69,9 @@ type secretIDAccessorStorageEntry struct {
}
// Checks if the Role represented by the RoleID still exists
func (b *backend) validateRoleID(s logical.Storage, roleID string) (*roleStorageEntry, string, error) {
func (b *backend) validateRoleID(ctx context.Context, s logical.Storage, roleID string) (*roleStorageEntry, string, error) {
// Look for the storage entry that maps the roleID to role
roleIDIndex, err := b.roleIDEntry(s, roleID)
roleIDIndex, err := b.roleIDEntry(ctx, s, roleID)
if err != nil {
return nil, "", err
}
@ -82,7 +83,7 @@ func (b *backend) validateRoleID(s logical.Storage, roleID string) (*roleStorage
lock.RLock()
defer lock.RUnlock()
role, err := b.roleEntry(s, roleIDIndex.Name)
role, err := b.roleEntry(ctx, s, roleIDIndex.Name)
if err != nil {
return nil, "", err
}
@ -94,7 +95,7 @@ func (b *backend) validateRoleID(s logical.Storage, roleID string) (*roleStorage
}
// Validates the supplied RoleID and SecretID
func (b *backend) validateCredentials(req *logical.Request, data *framework.FieldData) (*roleStorageEntry, string, map[string]string, string, error) {
func (b *backend) validateCredentials(ctx context.Context, req *logical.Request, data *framework.FieldData) (*roleStorageEntry, string, map[string]string, string, error) {
metadata := make(map[string]string)
// RoleID must be supplied during every login
roleID := strings.TrimSpace(data.Get("role_id").(string))
@ -103,7 +104,7 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel
}
// Validate the RoleID and get the Role entry
role, roleName, err := b.validateRoleID(req.Storage, roleID)
role, roleName, err := b.validateRoleID(ctx, req.Storage, roleID)
if err != nil {
return nil, "", metadata, "", err
}
@ -132,7 +133,7 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel
// Check if the SecretID supplied is valid. If use limit was specified
// on the SecretID, it will be decremented in this call.
var valid bool
valid, metadata, err = b.validateBindSecretID(req, roleName, secretID, role.HMACKey, role.BoundCIDRList)
valid, metadata, err = b.validateBindSecretID(ctx, req, roleName, secretID, role.HMACKey, role.BoundCIDRList)
if err != nil {
return nil, "", metadata, "", err
}
@ -160,7 +161,7 @@ func (b *backend) validateCredentials(req *logical.Request, data *framework.Fiel
}
// validateBindSecretID is used to determine if the given SecretID is a valid one.
func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
func (b *backend) validateBindSecretID(ctx context.Context, req *logical.Request, roleName, secretID,
hmacKey, roleBoundCIDRList string) (bool, map[string]string, error) {
secretIDHMAC, err := createHMAC(hmacKey, secretID)
if err != nil {
@ -180,7 +181,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
lock := b.secretIDLock(secretIDHMAC)
lock.RLock()
result, err := b.nonLockedSecretIDStorageEntry(req.Storage, roleNameHMAC, secretIDHMAC)
result, err := b.nonLockedSecretIDStorageEntry(ctx, req.Storage, roleNameHMAC, secretIDHMAC)
if err != nil {
lock.RUnlock()
return false, nil, err
@ -225,7 +226,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
defer lock.Unlock()
// Lock switching may change the data. Refresh the contents.
result, err = b.nonLockedSecretIDStorageEntry(req.Storage, roleNameHMAC, secretIDHMAC)
result, err = b.nonLockedSecretIDStorageEntry(ctx, req.Storage, roleNameHMAC, secretIDHMAC)
if err != nil {
return false, nil, err
}
@ -238,10 +239,10 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
// requests to use the same SecretID will fail.
if result.SecretIDNumUses == 1 {
// Delete the secret IDs accessor first
if err := b.deleteSecretIDAccessorEntry(req.Storage, result.SecretIDAccessor); err != nil {
if err := b.deleteSecretIDAccessorEntry(ctx, req.Storage, result.SecretIDAccessor); err != nil {
return false, nil, err
}
if err := req.Storage.Delete(entryIndex); err != nil {
if err := req.Storage.Delete(ctx, entryIndex); err != nil {
return false, nil, fmt.Errorf("failed to delete secret ID: %v", err)
}
} else {
@ -250,7 +251,7 @@ func (b *backend) validateBindSecretID(req *logical.Request, roleName, secretID,
result.LastUpdatedTime = time.Now()
if entry, err := logical.StorageEntryJSON(entryIndex, &result); err != nil {
return false, nil, fmt.Errorf("failed to decrement the use count for secret ID %q", secretID)
} else if err = req.Storage.Put(entry); err != nil {
} else if err = req.Storage.Put(ctx, entry); err != nil {
return false, nil, fmt.Errorf("failed to decrement the use count for secret ID %q", secretID)
}
}
@ -320,7 +321,7 @@ func (b *backend) secretIDAccessorLock(secretIDAccessor string) *locksutil.LockE
// storage. The entry will be indexed based on the given HMACs of both role
// name and the secret ID. This method will not acquire secret ID lock to fetch
// the storage entry. Locks need to be acquired before calling this method.
func (b *backend) nonLockedSecretIDStorageEntry(s logical.Storage, roleNameHMAC, secretIDHMAC string) (*secretIDStorageEntry, error) {
func (b *backend) nonLockedSecretIDStorageEntry(ctx context.Context, s logical.Storage, roleNameHMAC, secretIDHMAC string) (*secretIDStorageEntry, error) {
if secretIDHMAC == "" {
return nil, fmt.Errorf("missing secret ID HMAC")
}
@ -332,7 +333,7 @@ func (b *backend) nonLockedSecretIDStorageEntry(s logical.Storage, roleNameHMAC,
// Prepare the storage index at which the secret ID will be stored
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, secretIDHMAC)
entry, err := s.Get(entryIndex)
entry, err := s.Get(ctx, entryIndex)
if err != nil {
return nil, err
}
@ -360,7 +361,7 @@ func (b *backend) nonLockedSecretIDStorageEntry(s logical.Storage, roleNameHMAC,
}
if persistNeeded {
if err := b.nonLockedSetSecretIDStorageEntry(s, roleNameHMAC, secretIDHMAC, &result); err != nil {
if err := b.nonLockedSetSecretIDStorageEntry(ctx, s, roleNameHMAC, secretIDHMAC, &result); err != nil {
return nil, fmt.Errorf("failed to upgrade role storage entry %s", err)
}
}
@ -373,7 +374,7 @@ func (b *backend) nonLockedSecretIDStorageEntry(s logical.Storage, roleNameHMAC,
// role name and the secret ID. This method will not acquire secret ID lock to
// create/update the storage entry. Locks need to be acquired before calling
// this method.
func (b *backend) nonLockedSetSecretIDStorageEntry(s logical.Storage, roleNameHMAC, secretIDHMAC string, secretEntry *secretIDStorageEntry) error {
func (b *backend) nonLockedSetSecretIDStorageEntry(ctx context.Context, s logical.Storage, roleNameHMAC, secretIDHMAC string, secretEntry *secretIDStorageEntry) error {
if secretIDHMAC == "" {
return fmt.Errorf("missing secret ID HMAC")
}
@ -390,7 +391,7 @@ func (b *backend) nonLockedSetSecretIDStorageEntry(s logical.Storage, roleNameHM
if entry, err := logical.StorageEntryJSON(entryIndex, secretEntry); err != nil {
return err
} else if err = s.Put(entry); err != nil {
} else if err = s.Put(ctx, entry); err != nil {
return err
}
@ -398,7 +399,7 @@ func (b *backend) nonLockedSetSecretIDStorageEntry(s logical.Storage, roleNameHM
}
// registerSecretIDEntry creates a new storage entry for the given SecretID.
func (b *backend) registerSecretIDEntry(s logical.Storage, roleName, secretID, hmacKey string, secretEntry *secretIDStorageEntry) (*secretIDStorageEntry, error) {
func (b *backend) registerSecretIDEntry(ctx context.Context, s logical.Storage, roleName, secretID, hmacKey string, secretEntry *secretIDStorageEntry) (*secretIDStorageEntry, error) {
secretIDHMAC, err := createHMAC(hmacKey, secretID)
if err != nil {
return nil, fmt.Errorf("failed to create HMAC of secret ID: %v", err)
@ -411,7 +412,7 @@ func (b *backend) registerSecretIDEntry(s logical.Storage, roleName, secretID, h
lock := b.secretIDLock(secretIDHMAC)
lock.RLock()
entry, err := b.nonLockedSecretIDStorageEntry(s, roleNameHMAC, secretIDHMAC)
entry, err := b.nonLockedSecretIDStorageEntry(ctx, s, roleNameHMAC, secretIDHMAC)
if err != nil {
lock.RUnlock()
return nil, err
@ -428,7 +429,7 @@ func (b *backend) registerSecretIDEntry(s logical.Storage, roleName, secretID, h
defer lock.Unlock()
// But before saving a new entry, check if the secretID entry was created during the lock switch.
entry, err = b.nonLockedSecretIDStorageEntry(s, roleNameHMAC, secretIDHMAC)
entry, err = b.nonLockedSecretIDStorageEntry(ctx, s, roleNameHMAC, secretIDHMAC)
if err != nil {
return nil, err
}
@ -457,11 +458,11 @@ func (b *backend) registerSecretIDEntry(s logical.Storage, roleName, secretID, h
}
// Before storing the SecretID, store its accessor.
if err := b.createSecretIDAccessorEntry(s, secretEntry, secretIDHMAC); err != nil {
if err := b.createSecretIDAccessorEntry(ctx, s, secretEntry, secretIDHMAC); err != nil {
return nil, err
}
if err := b.nonLockedSetSecretIDStorageEntry(s, roleNameHMAC, secretIDHMAC, secretEntry); err != nil {
if err := b.nonLockedSetSecretIDStorageEntry(ctx, s, roleNameHMAC, secretIDHMAC, secretEntry); err != nil {
return nil, err
}
@ -470,7 +471,7 @@ func (b *backend) registerSecretIDEntry(s logical.Storage, roleName, secretID, h
// secretIDAccessorEntry is used to read the storage entry that maps an
// accessor to a secret_id.
func (b *backend) secretIDAccessorEntry(s logical.Storage, secretIDAccessor string) (*secretIDAccessorStorageEntry, error) {
func (b *backend) secretIDAccessorEntry(ctx context.Context, s logical.Storage, secretIDAccessor string) (*secretIDAccessorStorageEntry, error) {
if secretIDAccessor == "" {
return nil, fmt.Errorf("missing secretIDAccessor")
}
@ -488,7 +489,7 @@ func (b *backend) secretIDAccessorEntry(s logical.Storage, secretIDAccessor stri
accessorLock.RLock()
defer accessorLock.RUnlock()
if entry, err := s.Get(entryIndex); err != nil {
if entry, err := s.Get(ctx, entryIndex); err != nil {
return nil, err
} else if entry == nil {
return nil, nil
@ -502,7 +503,7 @@ func (b *backend) secretIDAccessorEntry(s logical.Storage, secretIDAccessor stri
// createSecretIDAccessorEntry creates an identifier for the SecretID. A storage index,
// mapping the accessor to the SecretID is also created. This method should
// be called when the lock for the corresponding SecretID is held.
func (b *backend) createSecretIDAccessorEntry(s logical.Storage, entry *secretIDStorageEntry, secretIDHMAC string) error {
func (b *backend) createSecretIDAccessorEntry(ctx context.Context, s logical.Storage, entry *secretIDStorageEntry, secretIDHMAC string) error {
// Create a random accessor
accessorUUID, err := uuid.GenerateUUID()
if err != nil {
@ -525,7 +526,7 @@ func (b *backend) createSecretIDAccessorEntry(s logical.Storage, entry *secretID
SecretIDHMAC: secretIDHMAC,
}); err != nil {
return err
} else if err = s.Put(entry); err != nil {
} else if err = s.Put(ctx, entry); err != nil {
return fmt.Errorf("failed to persist accessor index entry: %v", err)
}
@ -533,7 +534,7 @@ func (b *backend) createSecretIDAccessorEntry(s logical.Storage, entry *secretID
}
// deleteSecretIDAccessorEntry deletes the storage index mapping the accessor to a SecretID.
func (b *backend) deleteSecretIDAccessorEntry(s logical.Storage, secretIDAccessor string) error {
func (b *backend) deleteSecretIDAccessorEntry(ctx context.Context, s logical.Storage, secretIDAccessor string) error {
salt, err := b.Salt()
if err != nil {
return err
@ -545,7 +546,7 @@ func (b *backend) deleteSecretIDAccessorEntry(s logical.Storage, secretIDAccesso
defer accessorLock.Unlock()
// Delete the accessor of the SecretID first
if err := s.Delete(accessorEntryIndex); err != nil {
if err := s.Delete(ctx, accessorEntryIndex); err != nil {
return fmt.Errorf("failed to delete accessor storage entry: %v", err)
}
@ -554,7 +555,7 @@ func (b *backend) deleteSecretIDAccessorEntry(s logical.Storage, secretIDAccesso
// flushRoleSecrets deletes all the SecretIDs that belong to the given
// RoleID.
func (b *backend) flushRoleSecrets(s logical.Storage, roleName, hmacKey string) error {
func (b *backend) flushRoleSecrets(ctx context.Context, s logical.Storage, roleName, hmacKey string) error {
roleNameHMAC, err := createHMAC(hmacKey, roleName)
if err != nil {
return fmt.Errorf("failed to create HMAC of role_name: %v", err)
@ -564,7 +565,7 @@ func (b *backend) flushRoleSecrets(s logical.Storage, roleName, hmacKey string)
b.secretIDListingLock.RLock()
defer b.secretIDListingLock.RUnlock()
secretIDHMACs, err := s.List(fmt.Sprintf("secret_id/%s/", roleNameHMAC))
secretIDHMACs, err := s.List(ctx, fmt.Sprintf("secret_id/%s/", roleNameHMAC))
if err != nil {
return err
}
@ -573,7 +574,7 @@ func (b *backend) flushRoleSecrets(s logical.Storage, roleName, hmacKey string)
lock := b.secretIDLock(secretIDHMAC)
lock.Lock()
entryIndex := fmt.Sprintf("secret_id/%s/%s", roleNameHMAC, secretIDHMAC)
if err := s.Delete(entryIndex); err != nil {
if err := s.Delete(ctx, entryIndex); err != nil {
lock.Unlock()
return fmt.Errorf("error deleting SecretID %q from storage: %v", secretIDHMAC, err)
}

View File

@ -1,6 +1,7 @@
package awsauth
import (
"context"
"fmt"
"sync"
"time"
@ -13,12 +14,12 @@ import (
"github.com/patrickmn/go-cache"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b, err := Backend(conf)
if err != nil {
return nil, err
}
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@ -73,7 +74,7 @@ type backend struct {
// accounts using their IAM instance profile to get their credentials.
defaultAWSAccountID string
resolveArnToUniqueIDFunc func(logical.Storage, string) (string, error)
resolveArnToUniqueIDFunc func(context.Context, logical.Storage, string) (string, error)
}
func Backend(conf *logical.BackendConfig) (*backend, error) {
@ -138,13 +139,13 @@ func Backend(conf *logical.BackendConfig) (*backend, error) {
// not once in a minute, but once in an hour, controlled by 'tidyCooldownPeriod'.
// Tidying of blacklist and whitelist are by default enabled. This can be
// changed using `config/tidy/roletags` and `config/tidy/identities` endpoints.
func (b *backend) periodicFunc(req *logical.Request) error {
func (b *backend) periodicFunc(ctx context.Context, req *logical.Request) error {
// Run the tidy operations for the first time. Then run it when current
// time matches the nextTidyTime.
if b.nextTidyTime.IsZero() || !time.Now().Before(b.nextTidyTime) {
// safety_buffer defaults to 180 days for roletag blacklist
safety_buffer := 15552000
tidyBlacklistConfigEntry, err := b.lockedConfigTidyRoleTags(req.Storage)
tidyBlacklistConfigEntry, err := b.lockedConfigTidyRoleTags(ctx, req.Storage)
if err != nil {
return err
}
@ -160,12 +161,12 @@ func (b *backend) periodicFunc(req *logical.Request) error {
}
// tidy role tags if explicitly not disabled
if !skipBlacklistTidy {
b.tidyBlacklistRoleTag(req.Storage, safety_buffer)
b.tidyBlacklistRoleTag(ctx, req.Storage, safety_buffer)
}
// reset the safety_buffer to 72h
safety_buffer = 259200
tidyWhitelistConfigEntry, err := b.lockedConfigTidyIdentities(req.Storage)
tidyWhitelistConfigEntry, err := b.lockedConfigTidyIdentities(ctx, req.Storage)
if err != nil {
return err
}
@ -181,7 +182,7 @@ func (b *backend) periodicFunc(req *logical.Request) error {
}
// tidy identities if explicitly not disabled
if !skipWhitelistTidy {
b.tidyWhitelistIdentity(req.Storage, safety_buffer)
b.tidyWhitelistIdentity(ctx, req.Storage, safety_buffer)
}
// Update the time at which to run the tidy functions again.
@ -190,7 +191,7 @@ func (b *backend) periodicFunc(req *logical.Request) error {
return nil
}
func (b *backend) invalidate(key string) {
func (b *backend) invalidate(ctx context.Context, key string) {
switch key {
case "config/client":
b.configMutex.Lock()
@ -203,7 +204,7 @@ func (b *backend) invalidate(key string) {
// Putting this here so we can inject a fake resolver into the backend for unit testing
// purposes
func (b *backend) resolveArnToRealUniqueId(s logical.Storage, arn string) (string, error) {
func (b *backend) resolveArnToRealUniqueId(ctx context.Context, s logical.Storage, arn string) (string, error) {
entity, err := parseIamArn(arn)
if err != nil {
return "", err
@ -223,7 +224,7 @@ func (b *backend) resolveArnToRealUniqueId(s logical.Storage, arn string) (strin
if region == nil {
return "", fmt.Errorf("Unable to resolve partition %q to a region", entity.Partition)
}
iamClient, err := b.clientIAM(s, region.ID(), entity.AccountNumber)
iamClient, err := b.clientIAM(ctx, s, region.ID(), entity.AccountNumber)
if err != nil {
return "", err
}
@ -278,7 +279,7 @@ func getAnyRegionForAwsPartition(partitionId string) *endpoints.Region {
}
const backendHelp = `
aws-ec2 auth backend takes in PKCS#7 signature of an AWS EC2 instance and a client
aws-ec2 auth method takes in PKCS#7 signature of an AWS EC2 instance and a client
created nonce to authenticates the EC2 instance with Vault.
Authentication is backed by a preconfigured role in the backend. The role

View File

@ -30,7 +30,8 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -55,7 +56,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
}
// read the created role entry
roleEntry, err := b.lockedAWSRole(storage, "abcd-123")
roleEntry, err := b.lockedAWSRole(context.Background(), storage, "abcd-123")
if err != nil {
t.Fatal(err)
}
@ -83,7 +84,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
}
// parse the created role tag
rTag2, err := b.parseAndVerifyRoleTagValue(storage, val)
rTag2, err := b.parseAndVerifyRoleTagValue(context.Background(), storage, val)
if err != nil {
t.Fatal(err)
}
@ -122,7 +123,7 @@ func TestBackend_CreateParseVerifyRoleTag(t *testing.T) {
}
// get the entry of the newly created role entry
roleEntry2, err := b.lockedAWSRole(storage, "ami-6789")
roleEntry2, err := b.lockedAWSRole(context.Background(), storage, "ami-6789")
if err != nil {
t.Fatal(err)
}
@ -254,7 +255,8 @@ func TestBackend_ConfigTidyIdentities(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -308,7 +310,8 @@ func TestBackend_ConfigTidyRoleTags(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -362,7 +365,8 @@ func TestBackend_TidyIdentities(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -387,7 +391,8 @@ func TestBackend_TidyRoleTags(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -412,7 +417,8 @@ func TestBackend_ConfigClient(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -549,7 +555,8 @@ func TestBackend_pathConfigCertificate(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -704,7 +711,8 @@ func TestBackend_parseAndVerifyRoleTagValue(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -763,7 +771,7 @@ func TestBackend_parseAndVerifyRoleTagValue(t *testing.T) {
tagValue := resp.Data["tag_value"].(string)
// parse the value and check if the verifiable values match
rTag, err := b.parseAndVerifyRoleTagValue(storage, tagValue)
rTag, err := b.parseAndVerifyRoleTagValue(context.Background(), storage, tagValue)
if err != nil {
t.Fatalf("err: %s", err)
}
@ -785,7 +793,8 @@ func TestBackend_PathRoleTag(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -850,7 +859,8 @@ func TestBackend_PathBlacklistRoleTag(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -939,7 +949,7 @@ func TestBackend_PathBlacklistRoleTag(t *testing.T) {
}
// try to read the deleted entry
tagEntry, err := b.lockedBlacklistRoleTagEntry(storage, tag)
tagEntry, err := b.lockedBlacklistRoleTagEntry(context.Background(), storage, tag)
if err != nil {
t.Fatal(err)
}
@ -998,7 +1008,8 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndWhitelistIdentity(t *testing.
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -1190,7 +1201,8 @@ func TestBackend_pathStsConfig(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -1338,7 +1350,8 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -1442,11 +1455,11 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
}
fakeArn := "arn:aws:iam::123456789012:role/somePath/FakeRole"
fakeArnResolver := func(s logical.Storage, arn string) (string, error) {
fakeArnResolver := func(ctx context.Context, s logical.Storage, arn string) (string, error) {
if arn == fakeArn {
return fmt.Sprintf("FakeUniqueIdFor%s", fakeArn), nil
}
return b.resolveArnToRealUniqueId(s, arn)
return b.resolveArnToRealUniqueId(context.Background(), s, arn)
}
b.resolveArnToUniqueIDFunc = fakeArnResolver
@ -1615,6 +1628,40 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) {
if cachedArn == "" {
t.Errorf("got empty ARN back from user ID cache; expected full arn")
}
// Test for renewal with period
period := 600 * time.Second
roleData["period"] = period.String()
roleRequest.Path = "role/" + testValidRoleName
resp, err = b.HandleRequest(context.Background(), roleRequest)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("bad: failed to create wildcard role: resp:%#v\nerr:%v", resp, err)
}
loginData["role"] = testValidRoleName
resp, err = b.HandleRequest(context.Background(), loginRequest)
if err != nil {
t.Fatal(err)
}
if resp == nil || resp.Auth == nil || resp.IsError() {
t.Fatalf("bad: expected valid login: resp:%#v", resp)
}
renewReq = generateRenewRequest(storage, resp.Auth)
resp, err = b.pathLoginRenew(context.Background(), renewReq, empty_login_fd)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("got nil response from renew")
}
if resp.IsError() {
t.Fatalf("got error when renewing: %#v", *resp)
}
if resp.Auth.Period != period {
t.Fatalf("expected a period value of %s in the response, got: %s", period, resp.Auth.Period)
}
}
func generateRenewRequest(s logical.Storage, auth *logical.Auth) *logical.Request {
@ -1627,6 +1674,7 @@ func generateRenewRequest(s logical.Storage, auth *logical.Auth) *logical.Reques
renewReq.Auth.LeaseOptions = auth.LeaseOptions
renewReq.Auth.Policies = auth.Policies
renewReq.Auth.IssueTime = time.Now()
renewReq.Auth.Period = auth.Period
return renewReq
}

View File

@ -113,29 +113,51 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
func (h *CLIHandler) Help() string {
help := `
The AWS credential provider allows you to authenticate with
AWS IAM credentials. To use it, you specify valid AWS IAM credentials
in one of a number of ways. They can be specified explicitly on the
command line (which in general you should not do), via the standard AWS
environment variables (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and
AWS_SECURITY_TOKEN), via the ~/.aws/credentials file, or via an EC2
instance profile (in that order).
Usage: vault login -method=aws [CONFIG K=V...]
Example: vault auth -method=aws
The AWS auth method allows users to authenticate with AWS IAM
credentials. The AWS IAM credentials may be specified in a number of ways,
listed in order of precedence below:
If you need to explicitly pass in credentials, you would do it like this:
Example: vault auth -method=aws aws_access_key_id=<access key> aws_secret_access_key=<secret key> aws_security_token=<token>
1. Explicitly via the command line (not recommended)
Key/Value Pairs:
2. Via the standard AWS environment variables (AWS_ACCESS_KEY, etc.)
mount=aws The mountpoint for the AWS credential provider.
Defaults to "aws"
aws_access_key_id=<access key> Explicitly specified AWS access key
aws_secret_access_key=<secret key> Explicitly specified AWS secret key
aws_security_token=<token> Security token for temporary credentials
header_value The Value of the X-Vault-AWS-IAM-Server-ID header.
role The name of the role you're requesting a token for
`
3. Via the ~/.aws/credentials file
4. Via EC2 instance profile
Authenticate using locally stored credentials:
$ vault login -method=aws
Authenticate by passing keys:
$ vault login -method=aws aws_access_key_id=... aws_secret_access_key=...
Configuration:
aws_access_key_id=<string>
Explicit AWS access key ID
aws_secret_access_key=<string>
Explicit AWS secret access key
aws_security_token=<string>
Explicit AWS security token for temporary credentials
header_value=<string>
Value for the x-vault-aws-iam-server-id header in requests
mount=<string>
Path where the AWS credential method is mounted. This is usually provided
via the -path flag in the "vault login" command, but it can be specified
here as well. If specified here, it takes precedence over the value for
-path. The default value is "aws".
role=<string>
Name of the role to request a token against
`
return strings.TrimSpace(help)
}

View File

@ -1,6 +1,7 @@
package awsauth
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go/aws"
@ -21,13 +22,13 @@ import (
// * Static credentials from 'config/client'
// * Environment variables
// * Instance metadata role
func (b *backend) getRawClientConfig(s logical.Storage, region, clientType string) (*aws.Config, error) {
func (b *backend) getRawClientConfig(ctx context.Context, s logical.Storage, region, clientType string) (*aws.Config, error) {
credsConfig := &awsutil.CredentialsConfig{
Region: region,
}
// Read the configured secret key and access key
config, err := b.nonLockedClientConfigEntry(s)
config, err := b.nonLockedClientConfigEntry(ctx, s)
if err != nil {
return nil, err
}
@ -71,9 +72,9 @@ func (b *backend) getRawClientConfig(s logical.Storage, region, clientType strin
// It uses getRawClientConfig to obtain config for the runtime environemnt, and if
// stsRole is a non-empty string, it will use AssumeRole to obtain a set of assumed
// credentials. The credentials will expire after 15 minutes but will auto-refresh.
func (b *backend) getClientConfig(s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) {
func (b *backend) getClientConfig(ctx context.Context, s logical.Storage, region, stsRole, accountID, clientType string) (*aws.Config, error) {
config, err := b.getRawClientConfig(s, region, clientType)
config, err := b.getRawClientConfig(ctx, s, region, clientType)
if err != nil {
return nil, err
}
@ -81,7 +82,7 @@ func (b *backend) getClientConfig(s logical.Storage, region, stsRole, accountID,
return nil, fmt.Errorf("could not compile valid credentials through the default provider chain")
}
stsConfig, err := b.getRawClientConfig(s, region, "sts")
stsConfig, err := b.getRawClientConfig(ctx, s, region, "sts")
if stsConfig == nil {
return nil, fmt.Errorf("could not configure STS client")
}
@ -160,9 +161,9 @@ func (b *backend) setCachedUserId(userId, arn string) {
}
}
func (b *backend) stsRoleForAccount(s logical.Storage, accountID string) (string, error) {
func (b *backend) stsRoleForAccount(ctx context.Context, s logical.Storage, accountID string) (string, error) {
// Check if an STS configuration exists for the AWS account
sts, err := b.lockedAwsStsEntry(s, accountID)
sts, err := b.lockedAwsStsEntry(ctx, s, accountID)
if err != nil {
return "", fmt.Errorf("error fetching STS config for account ID %q: %q\n", accountID, err)
}
@ -174,8 +175,8 @@ func (b *backend) stsRoleForAccount(s logical.Storage, accountID string) (string
}
// clientEC2 creates a client to interact with AWS EC2 API
func (b *backend) clientEC2(s logical.Storage, region, accountID string) (*ec2.EC2, error) {
stsRole, err := b.stsRoleForAccount(s, accountID)
func (b *backend) clientEC2(ctx context.Context, s logical.Storage, region, accountID string) (*ec2.EC2, error) {
stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
if err != nil {
return nil, err
}
@ -198,7 +199,7 @@ func (b *backend) clientEC2(s logical.Storage, region, accountID string) (*ec2.E
// Create an AWS config object using a chain of providers
var awsConfig *aws.Config
awsConfig, err = b.getClientConfig(s, region, stsRole, accountID, "ec2")
awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "ec2")
if err != nil {
return nil, err
@ -223,8 +224,8 @@ func (b *backend) clientEC2(s logical.Storage, region, accountID string) (*ec2.E
}
// clientIAM creates a client to interact with AWS IAM API
func (b *backend) clientIAM(s logical.Storage, region, accountID string) (*iam.IAM, error) {
stsRole, err := b.stsRoleForAccount(s, accountID)
func (b *backend) clientIAM(ctx context.Context, s logical.Storage, region, accountID string) (*iam.IAM, error) {
stsRole, err := b.stsRoleForAccount(ctx, s, accountID)
if err != nil {
return nil, err
}
@ -247,7 +248,7 @@ func (b *backend) clientIAM(s logical.Storage, region, accountID string) (*iam.I
// Create an AWS config object using a chain of providers
var awsConfig *aws.Config
awsConfig, err = b.getClientConfig(s, region, stsRole, accountID, "iam")
awsConfig, err = b.getClientConfig(ctx, s, region, stsRole, accountID, "iam")
if err != nil {
return nil, err

View File

@ -9,7 +9,6 @@ import (
"math/big"
"strings"
"github.com/fatih/structs"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -131,7 +130,7 @@ func (b *backend) pathConfigCertificateExistenceCheck(ctx context.Context, req *
return false, fmt.Errorf("missing cert_name")
}
entry, err := b.lockedAWSPublicCertificateEntry(req.Storage, certName)
entry, err := b.lockedAWSPublicCertificateEntry(ctx, req.Storage, certName)
if err != nil {
return false, err
}
@ -143,7 +142,7 @@ func (b *backend) pathCertificatesList(ctx context.Context, req *logical.Request
b.configMutex.RLock()
defer b.configMutex.RUnlock()
certs, err := req.Storage.List("config/certificate/")
certs, err := req.Storage.List(ctx, "config/certificate/")
if err != nil {
return nil, err
}
@ -174,7 +173,7 @@ func decodePEMAndParseCertificate(certificate string) (*x509.Certificate, error)
// the PKCS7 signatures of the instance identity documents. This method will
// append the certificates registered using `config/certificate/<cert_name>`
// endpoint, along with the default certificate in the backend.
func (b *backend) awsPublicCertificates(s logical.Storage, isPkcs bool) ([]*x509.Certificate, error) {
func (b *backend) awsPublicCertificates(ctx context.Context, s logical.Storage, isPkcs bool) ([]*x509.Certificate, error) {
// Lock at beginning and use internal method so that we are consistent as
// we iterate through
b.configMutex.RLock()
@ -195,14 +194,14 @@ func (b *backend) awsPublicCertificates(s logical.Storage, isPkcs bool) ([]*x509
certs = append(certs, decodedCert)
// Get the list of all the registered certificates
registeredCerts, err := s.List("config/certificate/")
registeredCerts, err := s.List(ctx, "config/certificate/")
if err != nil {
return nil, err
}
// Iterate through each certificate, parse and append it to a slice
for _, cert := range registeredCerts {
certEntry, err := b.nonLockedAWSPublicCertificateEntry(s, cert)
certEntry, err := b.nonLockedAWSPublicCertificateEntry(ctx, s, cert)
if err != nil {
return nil, err
}
@ -226,7 +225,7 @@ func (b *backend) awsPublicCertificates(s logical.Storage, isPkcs bool) ([]*x509
// lockedSetAWSPublicCertificateEntry is used to store the AWS public key in
// the storage. This method acquires lock before creating or updating a storage
// entry.
func (b *backend) lockedSetAWSPublicCertificateEntry(s logical.Storage, certName string, certEntry *awsPublicCert) error {
func (b *backend) lockedSetAWSPublicCertificateEntry(ctx context.Context, s logical.Storage, certName string, certEntry *awsPublicCert) error {
if certName == "" {
return fmt.Errorf("missing certificate name")
}
@ -238,13 +237,13 @@ func (b *backend) lockedSetAWSPublicCertificateEntry(s logical.Storage, certName
b.configMutex.Lock()
defer b.configMutex.Unlock()
return b.nonLockedSetAWSPublicCertificateEntry(s, certName, certEntry)
return b.nonLockedSetAWSPublicCertificateEntry(ctx, s, certName, certEntry)
}
// nonLockedSetAWSPublicCertificateEntry is used to store the AWS public key in
// the storage. This method does not acquire lock before reading the storage.
// If locking is desired, use lockedSetAWSPublicCertificateEntry instead.
func (b *backend) nonLockedSetAWSPublicCertificateEntry(s logical.Storage, certName string, certEntry *awsPublicCert) error {
func (b *backend) nonLockedSetAWSPublicCertificateEntry(ctx context.Context, s logical.Storage, certName string, certEntry *awsPublicCert) error {
if certName == "" {
return fmt.Errorf("missing certificate name")
}
@ -261,24 +260,24 @@ func (b *backend) nonLockedSetAWSPublicCertificateEntry(s logical.Storage, certN
return fmt.Errorf("failed to create storage entry for AWS public key certificate")
}
return s.Put(entry)
return s.Put(ctx, entry)
}
// lockedAWSPublicCertificateEntry is used to get the configured AWS Public Key
// that is used to verify the PKCS#7 signature of the instance identity
// document.
func (b *backend) lockedAWSPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
func (b *backend) lockedAWSPublicCertificateEntry(ctx context.Context, s logical.Storage, certName string) (*awsPublicCert, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
return b.nonLockedAWSPublicCertificateEntry(s, certName)
return b.nonLockedAWSPublicCertificateEntry(ctx, s, certName)
}
// nonLockedAWSPublicCertificateEntry reads the certificate information from
// the storage. This method does not acquire lock before reading the storage.
// If locking is desired, use lockedAWSPublicCertificateEntry instead.
func (b *backend) nonLockedAWSPublicCertificateEntry(s logical.Storage, certName string) (*awsPublicCert, error) {
entry, err := s.Get("config/certificate/" + certName)
func (b *backend) nonLockedAWSPublicCertificateEntry(ctx context.Context, s logical.Storage, certName string) (*awsPublicCert, error) {
entry, err := s.Get(ctx, "config/certificate/"+certName)
if err != nil {
return nil, err
}
@ -298,7 +297,7 @@ func (b *backend) nonLockedAWSPublicCertificateEntry(s logical.Storage, certName
}
if persistNeeded {
if err := b.nonLockedSetAWSPublicCertificateEntry(s, certName, &certEntry); err != nil {
if err := b.nonLockedSetAWSPublicCertificateEntry(ctx, s, certName, &certEntry); err != nil {
return nil, err
}
}
@ -318,7 +317,7 @@ func (b *backend) pathConfigCertificateDelete(ctx context.Context, req *logical.
return logical.ErrorResponse("missing cert_name"), nil
}
return nil, req.Storage.Delete("config/certificate/" + certName)
return nil, req.Storage.Delete(ctx, "config/certificate/"+certName)
}
// pathConfigCertificateRead is used to view the configured AWS Public Key that
@ -329,7 +328,7 @@ func (b *backend) pathConfigCertificateRead(ctx context.Context, req *logical.Re
return logical.ErrorResponse("missing cert_name"), nil
}
certificateEntry, err := b.lockedAWSPublicCertificateEntry(req.Storage, certName)
certificateEntry, err := b.lockedAWSPublicCertificateEntry(ctx, req.Storage, certName)
if err != nil {
return nil, err
}
@ -338,7 +337,10 @@ func (b *backend) pathConfigCertificateRead(ctx context.Context, req *logical.Re
}
return &logical.Response{
Data: structs.New(certificateEntry).Map(),
Data: map[string]interface{}{
"aws_public_cert": certificateEntry.AWSPublicCert,
"type": certificateEntry.Type,
},
}, nil
}
@ -354,7 +356,7 @@ func (b *backend) pathConfigCertificateCreateUpdate(ctx context.Context, req *lo
defer b.configMutex.Unlock()
// Check if there is already a certificate entry registered
certEntry, err := b.nonLockedAWSPublicCertificateEntry(req.Storage, certName)
certEntry, err := b.nonLockedAWSPublicCertificateEntry(ctx, req.Storage, certName)
if err != nil {
return nil, err
}
@ -406,7 +408,7 @@ func (b *backend) pathConfigCertificateCreateUpdate(ctx context.Context, req *lo
}
// If none of the checks fail, save the provided certificate
if err := b.nonLockedSetAWSPublicCertificateEntry(req.Storage, certName, certEntry); err != nil {
if err := b.nonLockedSetAWSPublicCertificateEntry(ctx, req.Storage, certName, certEntry); err != nil {
return nil, err
}

View File

@ -66,7 +66,7 @@ func pathConfigClient(b *backend) *framework.Path {
// Establishes dichotomy of request operation between CreateOperation and UpdateOperation.
// Returning 'true' forces an UpdateOperation, CreateOperation otherwise.
func (b *backend) pathConfigClientExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.lockedClientConfigEntry(req.Storage)
entry, err := b.lockedClientConfigEntry(ctx, req.Storage)
if err != nil {
return false, err
}
@ -74,16 +74,16 @@ func (b *backend) pathConfigClientExistenceCheck(ctx context.Context, req *logic
}
// Fetch the client configuration required to access the AWS API, after acquiring an exclusive lock.
func (b *backend) lockedClientConfigEntry(s logical.Storage) (*clientConfig, error) {
func (b *backend) lockedClientConfigEntry(ctx context.Context, s logical.Storage) (*clientConfig, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
return b.nonLockedClientConfigEntry(s)
return b.nonLockedClientConfigEntry(ctx, s)
}
// Fetch the client configuration required to access the AWS API.
func (b *backend) nonLockedClientConfigEntry(s logical.Storage) (*clientConfig, error) {
entry, err := s.Get("config/client")
func (b *backend) nonLockedClientConfigEntry(ctx context.Context, s logical.Storage) (*clientConfig, error) {
entry, err := s.Get(ctx, "config/client")
if err != nil {
return nil, err
}
@ -99,7 +99,7 @@ func (b *backend) nonLockedClientConfigEntry(s logical.Storage) (*clientConfig,
}
func (b *backend) pathConfigClientRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
clientConfig, err := b.lockedClientConfigEntry(req.Storage)
clientConfig, err := b.lockedClientConfigEntry(ctx, req.Storage)
if err != nil {
return nil, err
}
@ -117,7 +117,7 @@ func (b *backend) pathConfigClientDelete(ctx context.Context, req *logical.Reque
b.configMutex.Lock()
defer b.configMutex.Unlock()
if err := req.Storage.Delete("config/client"); err != nil {
if err := req.Storage.Delete(ctx, "config/client"); err != nil {
return nil, err
}
@ -139,7 +139,7 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
b.configMutex.Lock()
defer b.configMutex.Unlock()
configEntry, err := b.nonLockedClientConfigEntry(req.Storage)
configEntry, err := b.nonLockedClientConfigEntry(ctx, req.Storage)
if err != nil {
return nil, err
}
@ -231,7 +231,7 @@ func (b *backend) pathConfigClientCreateUpdate(ctx context.Context, req *logical
}
if changedCreds || changedOtherConfig || req.Operation == logical.CreateOperation {
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
}
@ -261,7 +261,7 @@ Configure AWS IAM credentials that are used to query instance and role details f
`
const pathConfigClientHelpDesc = `
The aws-ec2 auth backend makes AWS API queries to retrieve information
The aws-ec2 auth method makes AWS API queries to retrieve information
regarding EC2 instances that perform login operations. The 'aws_secret_key' and
'aws_access_key' parameters configured here should map to an AWS IAM user that
has permission to make the following API queries:

View File

@ -16,7 +16,8 @@ func TestBackend_pathConfigClient(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}

View File

@ -66,7 +66,7 @@ func (b *backend) pathConfigStsExistenceCheck(ctx context.Context, req *logical.
return false, fmt.Errorf("missing account_id")
}
entry, err := b.lockedAwsStsEntry(req.Storage, accountID)
entry, err := b.lockedAwsStsEntry(ctx, req.Storage, accountID)
if err != nil {
return false, err
}
@ -78,7 +78,7 @@ func (b *backend) pathConfigStsExistenceCheck(ctx context.Context, req *logical.
func (b *backend) pathStsList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
sts, err := req.Storage.List("config/sts/")
sts, err := req.Storage.List(ctx, "config/sts/")
if err != nil {
return nil, err
}
@ -88,7 +88,7 @@ func (b *backend) pathStsList(ctx context.Context, req *logical.Request, data *f
// nonLockedSetAwsStsEntry creates or updates an STS role association with the given accountID
// This method does not acquire the write lock before creating or updating. If locking is
// desired, use lockedSetAwsStsEntry instead
func (b *backend) nonLockedSetAwsStsEntry(s logical.Storage, accountID string, stsEntry *awsStsEntry) error {
func (b *backend) nonLockedSetAwsStsEntry(ctx context.Context, s logical.Storage, accountID string, stsEntry *awsStsEntry) error {
if accountID == "" {
return fmt.Errorf("missing AWS account ID")
}
@ -106,12 +106,12 @@ func (b *backend) nonLockedSetAwsStsEntry(s logical.Storage, accountID string, s
return fmt.Errorf("failed to create storage entry for AWS STS configuration")
}
return s.Put(entry)
return s.Put(ctx, entry)
}
// lockedSetAwsStsEntry creates or updates an STS role association with the given accountID
// This method acquires the write lock before creating or updating the STS entry.
func (b *backend) lockedSetAwsStsEntry(s logical.Storage, accountID string, stsEntry *awsStsEntry) error {
func (b *backend) lockedSetAwsStsEntry(ctx context.Context, s logical.Storage, accountID string, stsEntry *awsStsEntry) error {
if accountID == "" {
return fmt.Errorf("missing AWS account ID")
}
@ -123,14 +123,14 @@ func (b *backend) lockedSetAwsStsEntry(s logical.Storage, accountID string, stsE
b.configMutex.Lock()
defer b.configMutex.Unlock()
return b.nonLockedSetAwsStsEntry(s, accountID, stsEntry)
return b.nonLockedSetAwsStsEntry(ctx, s, accountID, stsEntry)
}
// nonLockedAwsStsEntry returns the STS role associated with the given accountID.
// This method does not acquire the read lock before returning information. If locking is
// desired, use lockedAwsStsEntry instead
func (b *backend) nonLockedAwsStsEntry(s logical.Storage, accountID string) (*awsStsEntry, error) {
entry, err := s.Get("config/sts/" + accountID)
func (b *backend) nonLockedAwsStsEntry(ctx context.Context, s logical.Storage, accountID string) (*awsStsEntry, error) {
entry, err := s.Get(ctx, "config/sts/"+accountID)
if err != nil {
return nil, err
}
@ -147,11 +147,11 @@ func (b *backend) nonLockedAwsStsEntry(s logical.Storage, accountID string) (*aw
// lockedAwsStsEntry returns the STS role associated with the given accountID.
// This method acquires the read lock before returning the association.
func (b *backend) lockedAwsStsEntry(s logical.Storage, accountID string) (*awsStsEntry, error) {
func (b *backend) lockedAwsStsEntry(ctx context.Context, s logical.Storage, accountID string) (*awsStsEntry, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
return b.nonLockedAwsStsEntry(s, accountID)
return b.nonLockedAwsStsEntry(ctx, s, accountID)
}
// pathConfigStsRead is used to return information about an STS role/AWS accountID association
@ -161,7 +161,7 @@ func (b *backend) pathConfigStsRead(ctx context.Context, req *logical.Request, d
return logical.ErrorResponse("missing account id"), nil
}
stsEntry, err := b.lockedAwsStsEntry(req.Storage, accountID)
stsEntry, err := b.lockedAwsStsEntry(ctx, req.Storage, accountID)
if err != nil {
return nil, err
}
@ -185,7 +185,7 @@ func (b *backend) pathConfigStsCreateUpdate(ctx context.Context, req *logical.Re
defer b.configMutex.Unlock()
// Check if an STS role is already registered
stsEntry, err := b.nonLockedAwsStsEntry(req.Storage, accountID)
stsEntry, err := b.nonLockedAwsStsEntry(ctx, req.Storage, accountID)
if err != nil {
return nil, err
}
@ -206,7 +206,7 @@ func (b *backend) pathConfigStsCreateUpdate(ctx context.Context, req *logical.Re
}
// save the provided STS role
if err := b.nonLockedSetAwsStsEntry(req.Storage, accountID, stsEntry); err != nil {
if err := b.nonLockedSetAwsStsEntry(ctx, req.Storage, accountID, stsEntry); err != nil {
return nil, err
}
@ -223,7 +223,7 @@ func (b *backend) pathConfigStsDelete(ctx context.Context, req *logical.Request,
return logical.ErrorResponse("missing account id"), nil
}
return nil, req.Storage.Delete("config/sts/" + accountID)
return nil, req.Storage.Delete(ctx, "config/sts/"+accountID)
}
const pathConfigStsSyn = `

View File

@ -45,22 +45,22 @@ expiration, before it is removed from the backend storage.`,
}
func (b *backend) pathConfigTidyIdentityWhitelistExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.lockedConfigTidyIdentities(req.Storage)
entry, err := b.lockedConfigTidyIdentities(ctx, req.Storage)
if err != nil {
return false, err
}
return entry != nil, nil
}
func (b *backend) lockedConfigTidyIdentities(s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
func (b *backend) lockedConfigTidyIdentities(ctx context.Context, s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
return b.nonLockedConfigTidyIdentities(s)
return b.nonLockedConfigTidyIdentities(ctx, s)
}
func (b *backend) nonLockedConfigTidyIdentities(s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
entry, err := s.Get(identityWhitelistConfigPath)
func (b *backend) nonLockedConfigTidyIdentities(ctx context.Context, s logical.Storage) (*tidyWhitelistIdentityConfig, error) {
entry, err := s.Get(ctx, identityWhitelistConfigPath)
if err != nil {
return nil, err
}
@ -79,7 +79,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistCreateUpdate(ctx context.Contex
b.configMutex.Lock()
defer b.configMutex.Unlock()
configEntry, err := b.nonLockedConfigTidyIdentities(req.Storage)
configEntry, err := b.nonLockedConfigTidyIdentities(ctx, req.Storage)
if err != nil {
return nil, err
}
@ -106,7 +106,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistCreateUpdate(ctx context.Contex
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@ -114,7 +114,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistCreateUpdate(ctx context.Contex
}
func (b *backend) pathConfigTidyIdentityWhitelistRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
clientConfig, err := b.lockedConfigTidyIdentities(req.Storage)
clientConfig, err := b.lockedConfigTidyIdentities(ctx, req.Storage)
if err != nil {
return nil, err
}
@ -131,7 +131,7 @@ func (b *backend) pathConfigTidyIdentityWhitelistDelete(ctx context.Context, req
b.configMutex.Lock()
defer b.configMutex.Unlock()
return nil, req.Storage.Delete(identityWhitelistConfigPath)
return nil, req.Storage.Delete(ctx, identityWhitelistConfigPath)
}
type tidyWhitelistIdentityConfig struct {

View File

@ -47,22 +47,22 @@ Defaults to 4320h (180 days).`,
}
func (b *backend) pathConfigTidyRoletagBlacklistExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.lockedConfigTidyRoleTags(req.Storage)
entry, err := b.lockedConfigTidyRoleTags(ctx, req.Storage)
if err != nil {
return false, err
}
return entry != nil, nil
}
func (b *backend) lockedConfigTidyRoleTags(s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
func (b *backend) lockedConfigTidyRoleTags(ctx context.Context, s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
b.configMutex.RLock()
defer b.configMutex.RUnlock()
return b.nonLockedConfigTidyRoleTags(s)
return b.nonLockedConfigTidyRoleTags(ctx, s)
}
func (b *backend) nonLockedConfigTidyRoleTags(s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
entry, err := s.Get(roletagBlacklistConfigPath)
func (b *backend) nonLockedConfigTidyRoleTags(ctx context.Context, s logical.Storage) (*tidyBlacklistRoleTagConfig, error) {
entry, err := s.Get(ctx, roletagBlacklistConfigPath)
if err != nil {
return nil, err
}
@ -82,7 +82,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistCreateUpdate(ctx context.Context
b.configMutex.Lock()
defer b.configMutex.Unlock()
configEntry, err := b.nonLockedConfigTidyRoleTags(req.Storage)
configEntry, err := b.nonLockedConfigTidyRoleTags(ctx, req.Storage)
if err != nil {
return nil, err
}
@ -107,7 +107,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistCreateUpdate(ctx context.Context
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@ -115,7 +115,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistCreateUpdate(ctx context.Context
}
func (b *backend) pathConfigTidyRoletagBlacklistRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
clientConfig, err := b.lockedConfigTidyRoleTags(req.Storage)
clientConfig, err := b.lockedConfigTidyRoleTags(ctx, req.Storage)
if err != nil {
return nil, err
}
@ -132,7 +132,7 @@ func (b *backend) pathConfigTidyRoletagBlacklistDelete(ctx context.Context, req
b.configMutex.Lock()
defer b.configMutex.Unlock()
return nil, req.Storage.Delete(roletagBlacklistConfigPath)
return nil, req.Storage.Delete(ctx, roletagBlacklistConfigPath)
}
type tidyBlacklistRoleTagConfig struct {

View File

@ -46,7 +46,7 @@ func pathListIdentityWhitelist(b *backend) *framework.Path {
// pathWhitelistIdentitiesList is used to list all the instance IDs that are present
// in the identity whitelist. This will list both valid and expired entries.
func (b *backend) pathWhitelistIdentitiesList(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
identities, err := req.Storage.List("whitelist/identity/")
identities, err := req.Storage.List(ctx, "whitelist/identity/")
if err != nil {
return nil, err
}
@ -54,8 +54,8 @@ func (b *backend) pathWhitelistIdentitiesList(ctx context.Context, req *logical.
}
// Fetch an item from the whitelist given an instance ID.
func whitelistIdentityEntry(s logical.Storage, instanceID string) (*whitelistIdentity, error) {
entry, err := s.Get("whitelist/identity/" + instanceID)
func whitelistIdentityEntry(ctx context.Context, s logical.Storage, instanceID string) (*whitelistIdentity, error) {
entry, err := s.Get(ctx, "whitelist/identity/"+instanceID)
if err != nil {
return nil, err
}
@ -72,13 +72,13 @@ func whitelistIdentityEntry(s logical.Storage, instanceID string) (*whitelistIde
// Stores an instance ID and the information required to validate further login/renewal attempts from
// the same instance ID.
func setWhitelistIdentityEntry(s logical.Storage, instanceID string, identity *whitelistIdentity) error {
func setWhitelistIdentityEntry(ctx context.Context, s logical.Storage, instanceID string, identity *whitelistIdentity) error {
entry, err := logical.StorageEntryJSON("whitelist/identity/"+instanceID, identity)
if err != nil {
return err
}
if err := s.Put(entry); err != nil {
if err := s.Put(ctx, entry); err != nil {
return err
}
return nil
@ -91,7 +91,7 @@ func (b *backend) pathIdentityWhitelistDelete(ctx context.Context, req *logical.
return logical.ErrorResponse("missing instance_id"), nil
}
return nil, req.Storage.Delete("whitelist/identity/" + instanceID)
return nil, req.Storage.Delete(ctx, "whitelist/identity/"+instanceID)
}
// pathIdentityWhitelistRead is used to view an entry in the identity whitelist given an instance ID.
@ -101,7 +101,7 @@ func (b *backend) pathIdentityWhitelistRead(ctx context.Context, req *logical.Re
return logical.ErrorResponse("missing instance_id"), nil
}
entry, err := whitelistIdentityEntry(req.Storage, instanceID)
entry, err := whitelistIdentityEntry(ctx, req.Storage, instanceID)
if err != nil {
return nil, err
}

View File

@ -5,6 +5,7 @@ import (
"crypto/subtle"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"encoding/xml"
"fmt"
@ -153,9 +154,9 @@ func (b *backend) instanceIamRoleARN(iamClient *iam.IAM, instanceProfileName str
// validateInstance queries the status of the EC2 instance using AWS EC2 API
// and checks if the instance is running and is healthy
func (b *backend) validateInstance(s logical.Storage, instanceID, region, accountID string) (*ec2.Instance, error) {
func (b *backend) validateInstance(ctx context.Context, s logical.Storage, instanceID, region, accountID string) (*ec2.Instance, error) {
// Create an EC2 client to pull the instance information
ec2Client, err := b.clientEC2(s, region, accountID)
ec2Client, err := b.clientEC2(ctx, s, region, accountID)
if err != nil {
return nil, err
}
@ -255,7 +256,7 @@ func validateMetadata(clientNonce, pendingTime string, storedIdentity *whitelist
// Verifies the integrity of the instance identity document using its SHA256
// RSA signature. After verification, returns the unmarshaled instance identity
// document.
func (b *backend) verifyInstanceIdentitySignature(s logical.Storage, identityBytes, signatureBytes []byte) (*identityDocument, error) {
func (b *backend) verifyInstanceIdentitySignature(ctx context.Context, s logical.Storage, identityBytes, signatureBytes []byte) (*identityDocument, error) {
if len(identityBytes) == 0 {
return nil, fmt.Errorf("missing instance identity document")
}
@ -269,7 +270,7 @@ func (b *backend) verifyInstanceIdentitySignature(s logical.Storage, identityByt
// certificate and all the registered certificates via
// 'config/certificate/<cert_name>' endpoint, for verifying the RSA
// digest.
publicCerts, err := b.awsPublicCertificates(s, false)
publicCerts, err := b.awsPublicCertificates(ctx, s, false)
if err != nil {
return nil, err
}
@ -296,7 +297,7 @@ func (b *backend) verifyInstanceIdentitySignature(s logical.Storage, identityByt
// Verifies the correctness of the authenticated attributes present in the PKCS#7
// signature. After verification, extracts the instance identity document from the
// signature, parses it and returns it.
func (b *backend) parseIdentityDocument(s logical.Storage, pkcs7B64 string) (*identityDocument, error) {
func (b *backend) parseIdentityDocument(ctx context.Context, s logical.Storage, pkcs7B64 string) (*identityDocument, error) {
// Insert the header and footer for the signature to be able to pem decode it
pkcs7B64 = fmt.Sprintf("-----BEGIN PKCS7-----\n%s\n-----END PKCS7-----", pkcs7B64)
@ -315,7 +316,7 @@ func (b *backend) parseIdentityDocument(s logical.Storage, pkcs7B64 string) (*id
// Get the public certificates that are used to verify the signature.
// This returns a slice of certificates containing the default certificate
// and all the registered certificates via 'config/certificate/<cert_name>' endpoint
publicCerts, err := b.awsPublicCertificates(s, true)
publicCerts, err := b.awsPublicCertificates(ctx, s, true)
if err != nil {
return nil, err
}
@ -371,7 +372,7 @@ func (b *backend) pathLoginUpdate(ctx context.Context, req *logical.Request, dat
// error that means the instance doesn't meet the role requirements
// The second error return value indicates whether there's an error in even
// trying to validate those requirements
func (b *backend) verifyInstanceMeetsRoleRequirements(
func (b *backend) verifyInstanceMeetsRoleRequirements(ctx context.Context,
s logical.Storage, instance *ec2.Instance, roleEntry *awsRoleEntry, roleName string, identityDoc *identityDocument) (error, error) {
switch {
@ -469,7 +470,7 @@ func (b *backend) verifyInstanceMeetsRoleRequirements(
}
// Use instance profile ARN to fetch the associated role ARN
iamClient, err := b.clientIAM(s, identityDoc.Region, identityDoc.AccountID)
iamClient, err := b.clientIAM(ctx, s, identityDoc.Region, identityDoc.AccountID)
if err != nil {
return nil, fmt.Errorf("could not fetch IAM client: %v", err)
} else if iamClient == nil {
@ -529,7 +530,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
// Verify the signature of the identity document and unmarshal it
var identityDocParsed *identityDocument
if pkcs7B64 != "" {
identityDocParsed, err = b.parseIdentityDocument(req.Storage, pkcs7B64)
identityDocParsed, err = b.parseIdentityDocument(ctx, req.Storage, pkcs7B64)
if err != nil {
return nil, err
}
@ -537,7 +538,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
return logical.ErrorResponse("failed to verify the instance identity document using pkcs7"), nil
}
} else {
identityDocParsed, err = b.verifyInstanceIdentitySignature(req.Storage, identityDocBytes, signatureBytes)
identityDocParsed, err = b.verifyInstanceIdentitySignature(ctx, req.Storage, identityDocBytes, signatureBytes)
if err != nil {
return nil, err
}
@ -565,7 +566,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
}
// Get the entry for the role used by the instance
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
@ -580,7 +581,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
// Validate the instance ID by making a call to AWS EC2 DescribeInstances API
// and fetching the instance description. Validation succeeds only if the
// instance is in 'running' state.
instance, err := b.validateInstance(req.Storage, identityDocParsed.InstanceID, identityDocParsed.Region, identityDocParsed.AccountID)
instance, err := b.validateInstance(ctx, req.Storage, identityDocParsed.InstanceID, identityDocParsed.Region, identityDocParsed.AccountID)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to verify instance ID: %v", err)), nil
}
@ -591,7 +592,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
return logical.ErrorResponse(fmt.Sprintf("Region %q does not satisfy the constraint on role %q", identityDocParsed.Region, roleName)), nil
}
validationError, err := b.verifyInstanceMeetsRoleRequirements(req.Storage, instance, roleEntry, roleName, identityDocParsed)
validationError, err := b.verifyInstanceMeetsRoleRequirements(ctx, req.Storage, instance, roleEntry, roleName, identityDocParsed)
if err != nil {
return nil, err
}
@ -600,7 +601,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
}
// Get the entry from the identity whitelist, if there is one
storedIdentity, err := whitelistIdentityEntry(req.Storage, identityDocParsed.InstanceID)
storedIdentity, err := whitelistIdentityEntry(ctx, req.Storage, identityDocParsed.InstanceID)
if err != nil {
return nil, err
}
@ -681,7 +682,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
rTagMaxTTL := time.Duration(0)
var roleTagResp *roleTagLoginResponse
if roleEntry.RoleTag != "" {
roleTagResp, err := b.handleRoleTagLogin(req.Storage, roleName, roleEntry, instance)
roleTagResp, err := b.handleRoleTagLogin(ctx, req.Storage, roleName, roleEntry, instance)
if err != nil {
return nil, err
}
@ -749,7 +750,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
return logical.ErrorResponse("client nonce exceeding the limit of 128 characters"), nil
}
if err = setWhitelistIdentityEntry(req.Storage, identityDocParsed.InstanceID, storedIdentity); err != nil {
if err = setWhitelistIdentityEntry(ctx, req.Storage, identityDocParsed.InstanceID, storedIdentity); err != nil {
return nil, err
}
@ -799,7 +800,7 @@ func (b *backend) pathLoginUpdateEc2(ctx context.Context, req *logical.Request,
// handleRoleTagLogin is used to fetch the role tag of the instance and
// verifies it to be correct. Then the policies for the login request will be
// set off of the role tag, if certain creteria satisfies.
func (b *backend) handleRoleTagLogin(s logical.Storage, roleName string, roleEntry *awsRoleEntry, instance *ec2.Instance) (*roleTagLoginResponse, error) {
func (b *backend) handleRoleTagLogin(ctx context.Context, s logical.Storage, roleName string, roleEntry *awsRoleEntry, instance *ec2.Instance) (*roleTagLoginResponse, error) {
if roleEntry == nil {
return nil, fmt.Errorf("nil role entry")
}
@ -831,7 +832,7 @@ func (b *backend) handleRoleTagLogin(s logical.Storage, roleName string, roleEnt
}
// Parse the role tag into a struct, extract the plaintext part of it and verify its HMAC
rTag, err := b.parseAndVerifyRoleTagValue(s, rTagValue)
rTag, err := b.parseAndVerifyRoleTagValue(ctx, s, rTagValue)
if err != nil {
return nil, err
}
@ -848,7 +849,7 @@ func (b *backend) handleRoleTagLogin(s logical.Storage, roleName string, roleEnt
}
// Check if the role tag is blacklisted
blacklistEntry, err := b.lockedBlacklistRoleTagEntry(s, rTagValue)
blacklistEntry, err := b.lockedBlacklistRoleTagEntry(ctx, s, rTagValue)
if err != nil {
return nil, err
}
@ -895,7 +896,7 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d
if roleName == "" {
return nil, fmt.Errorf("error retrieving role_name during renewal")
}
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
@ -923,7 +924,7 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d
if !ok {
return nil, fmt.Errorf("no inferred AWS region in auth metadata")
}
_, err := b.validateInstance(req.Storage, instanceID, instanceRegion, req.Auth.Metadata["account_id"])
_, err := b.validateInstance(ctx, req.Storage, instanceID, instanceRegion, req.Auth.Metadata["account_id"])
if err != nil {
return nil, fmt.Errorf("failed to verify instance ID %q: %v", instanceID, err)
}
@ -955,7 +956,7 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d
if err != nil {
return nil, fmt.Errorf("error parsing ARN %q: %v", canonicalArn, err)
}
fullArn, err = b.fullArn(entity, req.Storage)
fullArn, err = b.fullArn(ctx, entity, req.Storage)
if err != nil {
return nil, fmt.Errorf("error looking up full ARN of entity %v: %v", entity, err)
}
@ -974,12 +975,17 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d
}
}
resp, err := framework.LeaseExtend(roleEntry.TTL, roleEntry.MaxTTL, b.System())(ctx, req, data)
if err != nil {
return nil, err
// If a period is provided, set that as part of resp.Auth.Period and return a
// response immediately. Let expiration manager handle renewal from there on.
if roleEntry.Period > time.Duration(0) {
resp := &logical.Response{
Auth: req.Auth,
}
resp.Auth.Period = roleEntry.Period
return resp, nil
}
resp.Auth.Period = roleEntry.Period
return resp, nil
return framework.LeaseExtend(roleEntry.TTL, roleEntry.MaxTTL, b.System())(ctx, req, data)
}
func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -1002,12 +1008,12 @@ func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, d
}
// Cross check that the instance is still in 'running' state
_, err := b.validateInstance(req.Storage, instanceID, region, accountID)
_, err := b.validateInstance(ctx, req.Storage, instanceID, region, accountID)
if err != nil {
return nil, fmt.Errorf("failed to verify instance ID %q: %q", instanceID, err)
}
storedIdentity, err := whitelistIdentityEntry(req.Storage, instanceID)
storedIdentity, err := whitelistIdentityEntry(ctx, req.Storage, instanceID)
if err != nil {
return nil, err
}
@ -1016,7 +1022,7 @@ func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, d
}
// Ensure that role entry is not deleted
roleEntry, err := b.lockedAWSRole(req.Storage, storedIdentity.Role)
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, storedIdentity.Role)
if err != nil {
return nil, err
}
@ -1055,16 +1061,21 @@ func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, d
// Updating the expiration time is required for the tidy operation on the
// whitelist identity storage items
if err = setWhitelistIdentityEntry(req.Storage, instanceID, storedIdentity); err != nil {
if err = setWhitelistIdentityEntry(ctx, req.Storage, instanceID, storedIdentity); err != nil {
return nil, err
}
resp, err := framework.LeaseExtend(roleEntry.TTL, shortestMaxTTL, b.System())(ctx, req, data)
if err != nil {
return nil, err
// If a period is provided, set that as part of resp.Auth.Period and return a
// response immediately. Let expiration manager handle renewal from there on.
if roleEntry.Period > time.Duration(0) {
resp := &logical.Response{
Auth: req.Auth,
}
resp.Auth.Period = roleEntry.Period
return resp, nil
}
resp.Auth.Period = roleEntry.Period
return resp, nil
return framework.LeaseExtend(roleEntry.TTL, shortestMaxTTL, b.System())(ctx, req, data)
}
func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -1116,7 +1127,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
return logical.ErrorResponse("nil response when parsing iam_request_headers"), nil
}
config, err := b.lockedClientConfigEntry(req.Storage)
config, err := b.lockedClientConfigEntry(ctx, req.Storage)
if err != nil {
return logical.ErrorResponse("error getting configuration"), nil
}
@ -1164,7 +1175,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
roleName = entity.FriendlyName
}
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
@ -1189,7 +1200,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
if strings.HasSuffix(roleEntry.BoundIamPrincipalARN, "*") {
fullArn := b.getCachedUserId(callerUniqueId)
if fullArn == "" {
fullArn, err = b.fullArn(entity, req.Storage)
fullArn, err = b.fullArn(ctx, entity, req.Storage)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("error looking up full ARN of entity %v: %v", entity, err)), nil
}
@ -1213,7 +1224,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
inferredEntityType := ""
inferredEntityID := ""
if roleEntry.InferredEntityType == ec2EntityType {
instance, err := b.validateInstance(req.Storage, entity.SessionInfo, roleEntry.InferredAWSRegion, callerID.Account)
instance, err := b.validateInstance(ctx, req.Storage, entity.SessionInfo, roleEntry.InferredAWSRegion, callerID.Account)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to verify %s as a valid EC2 instance in region %s", entity.SessionInfo, roleEntry.InferredAWSRegion)), nil
}
@ -1228,7 +1239,7 @@ func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request,
PendingTime: instance.LaunchTime.Format(time.RFC3339),
}
validationError, err := b.verifyInstanceMeetsRoleRequirements(req.Storage, instance, roleEntry, roleName, identityDoc)
validationError, err := b.verifyInstanceMeetsRoleRequirements(ctx, req.Storage, instance, roleEntry, roleName, identityDoc)
if err != nil {
return nil, err
}
@ -1467,11 +1478,15 @@ func parseIamRequestHeaders(headersB64 string) (http.Header, error) {
switch typedValue := v.(type) {
case string:
headers.Add(k, typedValue)
case json.Number:
headers.Add(k, typedValue.String())
case []interface{}:
for _, individualVal := range typedValue {
switch possibleStrVal := individualVal.(type) {
case string:
headers.Add(k, possibleStrVal)
case json.Number:
headers.Add(k, possibleStrVal.String())
default:
return nil, fmt.Errorf("header %q contains value %q that has type %s, not string", k, individualVal, reflect.TypeOf(individualVal))
}
@ -1572,9 +1587,9 @@ func (e *iamEntity) canonicalArn() string {
}
// This returns the "full" ARN of an iamEntity, how it would be referred to in AWS proper
func (b *backend) fullArn(e *iamEntity, s logical.Storage) (string, error) {
func (b *backend) fullArn(ctx context.Context, e *iamEntity, s logical.Storage) (string, error) {
// Not assuming path is reliable for any entity types
client, err := b.clientIAM(s, getAnyRegionForAwsPartition(e.Partition).ID(), e.AccountNumber)
client, err := b.clientIAM(ctx, s, getAnyRegionForAwsPartition(e.Partition).ID(), e.AccountNumber)
if err != nil {
return "", fmt.Errorf("error creating IAM client: %v", err)
}

View File

@ -204,7 +204,7 @@ func pathListRoles(b *backend) *framework.Path {
// Establishes dichotomy of request operation between CreateOperation and UpdateOperation.
// Returning 'true' forces an UpdateOperation, CreateOperation otherwise.
func (b *backend) pathRoleExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.lockedAWSRole(req.Storage, strings.ToLower(data.Get("role").(string)))
entry, err := b.lockedAWSRole(ctx, req.Storage, strings.ToLower(data.Get("role").(string)))
if err != nil {
return false, err
}
@ -213,13 +213,13 @@ func (b *backend) pathRoleExistenceCheck(ctx context.Context, req *logical.Reque
// lockedAWSRole returns the properties set on the given role. This method
// acquires the read lock before reading the role from the storage.
func (b *backend) lockedAWSRole(s logical.Storage, roleName string) (*awsRoleEntry, error) {
func (b *backend) lockedAWSRole(ctx context.Context, s logical.Storage, roleName string) (*awsRoleEntry, error) {
if roleName == "" {
return nil, fmt.Errorf("missing role name")
}
b.roleMutex.RLock()
roleEntry, err := b.nonLockedAWSRole(s, roleName)
roleEntry, err := b.nonLockedAWSRole(ctx, s, roleName)
// we manually unlock rather than defer the unlock because we might need to grab
// a read/write lock in the upgrade path
b.roleMutex.RUnlock()
@ -229,7 +229,7 @@ func (b *backend) lockedAWSRole(s logical.Storage, roleName string) (*awsRoleEnt
if roleEntry == nil {
return nil, nil
}
needUpgrade, err := b.upgradeRoleEntry(s, roleEntry)
needUpgrade, err := b.upgradeRoleEntry(ctx, s, roleEntry)
if err != nil {
return nil, fmt.Errorf("error upgrading roleEntry: %v", err)
}
@ -238,7 +238,7 @@ func (b *backend) lockedAWSRole(s logical.Storage, roleName string) (*awsRoleEnt
defer b.roleMutex.Unlock()
// Now that we have a R/W lock, we need to re-read the role entry in case it was
// written to between releasing the read lock and acquiring the write lock
roleEntry, err = b.nonLockedAWSRole(s, roleName)
roleEntry, err = b.nonLockedAWSRole(ctx, s, roleName)
if err != nil {
return nil, err
}
@ -247,11 +247,11 @@ func (b *backend) lockedAWSRole(s logical.Storage, roleName string) (*awsRoleEnt
return nil, nil
}
// now re-check to see if we need to upgrade
if needUpgrade, err = b.upgradeRoleEntry(s, roleEntry); err != nil {
if needUpgrade, err = b.upgradeRoleEntry(ctx, s, roleEntry); err != nil {
return nil, fmt.Errorf("error upgrading roleEntry: %v", err)
}
if needUpgrade {
if err = b.nonLockedSetAWSRole(s, roleName, roleEntry); err != nil {
if err = b.nonLockedSetAWSRole(ctx, s, roleName, roleEntry); err != nil {
return nil, fmt.Errorf("error saving upgraded roleEntry: %v", err)
}
}
@ -261,7 +261,7 @@ func (b *backend) lockedAWSRole(s logical.Storage, roleName string) (*awsRoleEnt
// lockedSetAWSRole creates or updates a role in the storage. This method
// acquires the write lock before creating or updating the role at the storage.
func (b *backend) lockedSetAWSRole(s logical.Storage, roleName string, roleEntry *awsRoleEntry) error {
func (b *backend) lockedSetAWSRole(ctx context.Context, s logical.Storage, roleName string, roleEntry *awsRoleEntry) error {
if roleName == "" {
return fmt.Errorf("missing role name")
}
@ -273,13 +273,13 @@ func (b *backend) lockedSetAWSRole(s logical.Storage, roleName string, roleEntry
b.roleMutex.Lock()
defer b.roleMutex.Unlock()
return b.nonLockedSetAWSRole(s, roleName, roleEntry)
return b.nonLockedSetAWSRole(ctx, s, roleName, roleEntry)
}
// nonLockedSetAWSRole creates or updates a role in the storage. This method
// does not acquire the write lock before reading the role from the storage. If
// locking is desired, use lockedSetAWSRole instead.
func (b *backend) nonLockedSetAWSRole(s logical.Storage, roleName string,
func (b *backend) nonLockedSetAWSRole(ctx context.Context, s logical.Storage, roleName string,
roleEntry *awsRoleEntry) error {
if roleName == "" {
return fmt.Errorf("missing role name")
@ -294,7 +294,7 @@ func (b *backend) nonLockedSetAWSRole(s logical.Storage, roleName string,
return err
}
if err := s.Put(entry); err != nil {
if err := s.Put(ctx, entry); err != nil {
return err
}
@ -303,7 +303,7 @@ func (b *backend) nonLockedSetAWSRole(s logical.Storage, roleName string,
// If needed, updates the role entry and returns a bool indicating if it was updated
// (and thus needs to be persisted)
func (b *backend) upgradeRoleEntry(s logical.Storage, roleEntry *awsRoleEntry) (bool, error) {
func (b *backend) upgradeRoleEntry(ctx context.Context, s logical.Storage, roleEntry *awsRoleEntry) (bool, error) {
if roleEntry == nil {
return false, fmt.Errorf("received nil roleEntry")
}
@ -331,7 +331,7 @@ func (b *backend) upgradeRoleEntry(s logical.Storage, roleEntry *awsRoleEntry) (
roleEntry.BoundIamPrincipalARN != "" &&
roleEntry.BoundIamPrincipalID == "" &&
!strings.HasSuffix(roleEntry.BoundIamPrincipalARN, "*") {
principalId, err := b.resolveArnToUniqueIDFunc(s, roleEntry.BoundIamPrincipalARN)
principalId, err := b.resolveArnToUniqueIDFunc(ctx, s, roleEntry.BoundIamPrincipalARN)
if err != nil {
return false, err
}
@ -349,12 +349,12 @@ func (b *backend) upgradeRoleEntry(s logical.Storage, roleEntry *awsRoleEntry) (
// This method also does NOT check to see if a role upgrade is required. It is
// the responsibility of the caller to check if a role upgrade is required and,
// if so, to upgrade the role
func (b *backend) nonLockedAWSRole(s logical.Storage, roleName string) (*awsRoleEntry, error) {
func (b *backend) nonLockedAWSRole(ctx context.Context, s logical.Storage, roleName string) (*awsRoleEntry, error) {
if roleName == "" {
return nil, fmt.Errorf("missing role name")
}
entry, err := s.Get("role/" + strings.ToLower(roleName))
entry, err := s.Get(ctx, "role/"+strings.ToLower(roleName))
if err != nil {
return nil, err
}
@ -380,7 +380,7 @@ func (b *backend) pathRoleDelete(ctx context.Context, req *logical.Request, data
b.roleMutex.Lock()
defer b.roleMutex.Unlock()
return nil, req.Storage.Delete("role/" + strings.ToLower(roleName))
return nil, req.Storage.Delete(ctx, "role/"+strings.ToLower(roleName))
}
// pathRoleList is used to list all the AMI IDs registered with Vault.
@ -388,7 +388,7 @@ func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, data *
b.roleMutex.RLock()
defer b.roleMutex.RUnlock()
roles, err := req.Storage.List("role/")
roles, err := req.Storage.List(ctx, "role/")
if err != nil {
return nil, err
}
@ -397,7 +397,7 @@ func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, data *
// pathRoleRead is used to view the information registered for a given AMI ID.
func (b *backend) pathRoleRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
roleEntry, err := b.lockedAWSRole(req.Storage, strings.ToLower(data.Get("role").(string)))
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, strings.ToLower(data.Get("role").(string)))
if err != nil {
return nil, err
}
@ -431,19 +431,19 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
b.roleMutex.Lock()
defer b.roleMutex.Unlock()
roleEntry, err := b.nonLockedAWSRole(req.Storage, roleName)
roleEntry, err := b.nonLockedAWSRole(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
if roleEntry == nil {
roleEntry = &awsRoleEntry{}
} else {
needUpdate, err := b.upgradeRoleEntry(req.Storage, roleEntry)
needUpdate, err := b.upgradeRoleEntry(ctx, req.Storage, roleEntry)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to update roleEntry: %v", err)), nil
}
if needUpdate {
err = b.nonLockedSetAWSRole(req.Storage, roleName, roleEntry)
err = b.nonLockedSetAWSRole(ctx, req.Storage, roleName, roleEntry)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed to save upgraded roleEntry: %v", err)), nil
}
@ -500,8 +500,8 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
// This allows the user to sumbit an update with the same ARN to force Vault
// to re-resolve the ARN to the unique ID, in case an entity was deleted and
// recreated
if roleEntry.ResolveAWSUniqueIDs && !strings.HasSuffix(roleEntry.BoundIamPrincipalARN, "*") {
principalID, err := b.resolveArnToUniqueIDFunc(req.Storage, principalARN)
if roleEntry.ResolveAWSUniqueIDs && roleEntry.BoundIamPrincipalARN != "" && !strings.HasSuffix(roleEntry.BoundIamPrincipalARN, "*") {
principalID, err := b.resolveArnToUniqueIDFunc(ctx, req.Storage, principalARN)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("failed updating the unique ID of ARN %#v: %#v", principalARN, err)), nil
}
@ -512,7 +512,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
}
} else if roleEntry.ResolveAWSUniqueIDs && roleEntry.BoundIamPrincipalARN != "" && !strings.HasSuffix(roleEntry.BoundIamPrincipalARN, "*") {
// we're turning on resolution on this role, so ensure we update it
principalID, err := b.resolveArnToUniqueIDFunc(req.Storage, roleEntry.BoundIamPrincipalARN)
principalID, err := b.resolveArnToUniqueIDFunc(ctx, req.Storage, roleEntry.BoundIamPrincipalARN)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("unable to resolve ARN %#v to internal ID: %#v", roleEntry.BoundIamPrincipalARN, err)), nil
}
@ -731,7 +731,7 @@ func (b *backend) pathRoleCreateUpdate(ctx context.Context, req *logical.Request
}
}
if err := b.nonLockedSetAWSRole(req.Storage, roleName, roleEntry); err != nil {
if err := b.nonLockedSetAWSRole(ctx, req.Storage, roleName, roleEntry); err != nil {
return nil, err
}

View File

@ -77,7 +77,7 @@ func (b *backend) pathRoleTagUpdate(ctx context.Context, req *logical.Request, d
}
// Fetch the role entry
roleEntry, err := b.lockedAWSRole(req.Storage, roleName)
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, roleName)
if err != nil {
return nil, err
}
@ -288,7 +288,7 @@ func prepareRoleTagPlaintextValue(rTag *roleTag) (string, error) {
// Parses the tag from string form into a struct form. This method
// also verifies the correctness of the parsed role tag.
func (b *backend) parseAndVerifyRoleTagValue(s logical.Storage, tag string) (*roleTag, error) {
func (b *backend) parseAndVerifyRoleTagValue(ctx context.Context, s logical.Storage, tag string) (*roleTag, error) {
tagItems := strings.Split(tag, ":")
// Tag must contain version, nonce, policies and HMAC
@ -349,7 +349,7 @@ func (b *backend) parseAndVerifyRoleTagValue(s logical.Storage, tag string) (*ro
return nil, fmt.Errorf("missing role name")
}
roleEntry, err := b.lockedAWSRole(s, rTag.Role)
roleEntry, err := b.lockedAWSRole(ctx, s, rTag.Role)
if err != nil {
return nil, err
}

View File

@ -20,7 +20,8 @@ func TestBackend_pathRoleEc2(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -117,6 +118,20 @@ func TestBackend_pathRoleEc2(t *testing.T) {
t.Fatal(err)
}
data["bound_iam_principal_arn"] = ""
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.UpdateOperation,
Path: "role/ami-abcd456",
Data: data,
Storage: storage,
})
if err != nil {
t.Fatal(err)
}
if resp != nil && resp.IsError() {
t.Fatalf("failed to update role with empty bound_iam_principal_arn: %s", resp.Data["error"])
}
resp, err = b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.ListOperation,
Path: "roles",
@ -164,7 +179,7 @@ func Test_enableIamIDResolution(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -239,7 +254,7 @@ func TestBackend_pathIam(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -403,7 +418,7 @@ func TestBackend_pathRoleMixedTypes(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -509,7 +524,8 @@ func TestAwsEc2_RoleCrud(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -635,7 +651,8 @@ func TestAwsEc2_RoleDurationSeconds(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = b.Setup(config)
err = b.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -679,6 +696,6 @@ func TestAwsEc2_RoleDurationSeconds(t *testing.T) {
}
}
func resolveArnToFakeUniqueId(s logical.Storage, arn string) (string, error) {
func resolveArnToFakeUniqueId(ctx context.Context, s logical.Storage, arn string) (string, error) {
return "FakeUniqueId1", nil
}

View File

@ -50,7 +50,7 @@ func (b *backend) pathRoletagBlacklistsList(ctx context.Context, req *logical.Re
b.blacklistMutex.RLock()
defer b.blacklistMutex.RUnlock()
tags, err := req.Storage.List("blacklist/roletag/")
tags, err := req.Storage.List(ctx, "blacklist/roletag/")
if err != nil {
return nil, err
}
@ -71,15 +71,15 @@ func (b *backend) pathRoletagBlacklistsList(ctx context.Context, req *logical.Re
// Fetch an entry from the role tag blacklist for a given tag.
// This method takes a role tag in its original form and not a base64 encoded form.
func (b *backend) lockedBlacklistRoleTagEntry(s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
func (b *backend) lockedBlacklistRoleTagEntry(ctx context.Context, s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
b.blacklistMutex.RLock()
defer b.blacklistMutex.RUnlock()
return b.nonLockedBlacklistRoleTagEntry(s, tag)
return b.nonLockedBlacklistRoleTagEntry(ctx, s, tag)
}
func (b *backend) nonLockedBlacklistRoleTagEntry(s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
entry, err := s.Get("blacklist/roletag/" + base64.StdEncoding.EncodeToString([]byte(tag)))
func (b *backend) nonLockedBlacklistRoleTagEntry(ctx context.Context, s logical.Storage, tag string) (*roleTagBlacklistEntry, error) {
entry, err := s.Get(ctx, "blacklist/roletag/"+base64.StdEncoding.EncodeToString([]byte(tag)))
if err != nil {
return nil, err
}
@ -104,7 +104,7 @@ func (b *backend) pathRoletagBlacklistDelete(ctx context.Context, req *logical.R
return logical.ErrorResponse("missing role_tag"), nil
}
return nil, req.Storage.Delete("blacklist/roletag/" + base64.StdEncoding.EncodeToString([]byte(tag)))
return nil, req.Storage.Delete(ctx, "blacklist/roletag/"+base64.StdEncoding.EncodeToString([]byte(tag)))
}
// If the given role tag is blacklisted, returns the details of the blacklist entry.
@ -115,7 +115,7 @@ func (b *backend) pathRoletagBlacklistRead(ctx context.Context, req *logical.Req
return logical.ErrorResponse("missing role_tag"), nil
}
entry, err := b.lockedBlacklistRoleTagEntry(req.Storage, tag)
entry, err := b.lockedBlacklistRoleTagEntry(ctx, req.Storage, tag)
if err != nil {
return nil, err
}
@ -154,7 +154,7 @@ func (b *backend) pathRoletagBlacklistUpdate(ctx context.Context, req *logical.R
}
// Parse and verify the role tag from string form to a struct form and verify it.
rTag, err := b.parseAndVerifyRoleTagValue(req.Storage, tag)
rTag, err := b.parseAndVerifyRoleTagValue(ctx, req.Storage, tag)
if err != nil {
return nil, err
}
@ -163,7 +163,7 @@ func (b *backend) pathRoletagBlacklistUpdate(ctx context.Context, req *logical.R
}
// Get the entry for the role mentioned in the role tag.
roleEntry, err := b.lockedAWSRole(req.Storage, rTag.Role)
roleEntry, err := b.lockedAWSRole(ctx, req.Storage, rTag.Role)
if err != nil {
return nil, err
}
@ -175,7 +175,7 @@ func (b *backend) pathRoletagBlacklistUpdate(ctx context.Context, req *logical.R
defer b.blacklistMutex.Unlock()
// Check if the role tag is already blacklisted. If yes, update it.
blEntry, err := b.nonLockedBlacklistRoleTagEntry(req.Storage, tag)
blEntry, err := b.nonLockedBlacklistRoleTagEntry(ctx, req.Storage, tag)
if err != nil {
return nil, err
}
@ -211,7 +211,7 @@ func (b *backend) pathRoletagBlacklistUpdate(ctx context.Context, req *logical.R
}
// Store the blacklist entry.
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@ -32,7 +32,7 @@ expiration, before it is removed from the backend storage.`,
}
// tidyWhitelistIdentity is used to delete entries in the whitelist that are expired.
func (b *backend) tidyWhitelistIdentity(s logical.Storage, safety_buffer int) error {
func (b *backend) tidyWhitelistIdentity(ctx context.Context, s logical.Storage, safety_buffer int) error {
grabbed := atomic.CompareAndSwapUint32(&b.tidyWhitelistCASGuard, 0, 1)
if grabbed {
defer atomic.StoreUint32(&b.tidyWhitelistCASGuard, 0)
@ -42,13 +42,13 @@ func (b *backend) tidyWhitelistIdentity(s logical.Storage, safety_buffer int) er
bufferDuration := time.Duration(safety_buffer) * time.Second
identities, err := s.List("whitelist/identity/")
identities, err := s.List(ctx, "whitelist/identity/")
if err != nil {
return err
}
for _, instanceID := range identities {
identityEntry, err := s.Get("whitelist/identity/" + instanceID)
identityEntry, err := s.Get(ctx, "whitelist/identity/"+instanceID)
if err != nil {
return fmt.Errorf("error fetching identity of instanceID %s: %s", instanceID, err)
}
@ -67,7 +67,7 @@ func (b *backend) tidyWhitelistIdentity(s logical.Storage, safety_buffer int) er
}
if time.Now().After(result.ExpirationTime.Add(bufferDuration)) {
if err := s.Delete("whitelist/identity" + instanceID); err != nil {
if err := s.Delete(ctx, "whitelist/identity"+instanceID); err != nil {
return fmt.Errorf("error deleting identity of instanceID %s from storage: %s", instanceID, err)
}
}
@ -78,7 +78,7 @@ func (b *backend) tidyWhitelistIdentity(s logical.Storage, safety_buffer int) er
// pathTidyIdentityWhitelistUpdate is used to delete entries in the whitelist that are expired.
func (b *backend) pathTidyIdentityWhitelistUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
return nil, b.tidyWhitelistIdentity(req.Storage, data.Get("safety_buffer").(int))
return nil, b.tidyWhitelistIdentity(ctx, req.Storage, data.Get("safety_buffer").(int))
}
const pathTidyIdentityWhitelistSyn = `

View File

@ -32,7 +32,7 @@ expiration, before it is removed from the backend storage.`,
}
// tidyBlacklistRoleTag is used to clean-up the entries in the role tag blacklist.
func (b *backend) tidyBlacklistRoleTag(s logical.Storage, safety_buffer int) error {
func (b *backend) tidyBlacklistRoleTag(ctx context.Context, s logical.Storage, safety_buffer int) error {
grabbed := atomic.CompareAndSwapUint32(&b.tidyBlacklistCASGuard, 0, 1)
if grabbed {
defer atomic.StoreUint32(&b.tidyBlacklistCASGuard, 0)
@ -41,13 +41,13 @@ func (b *backend) tidyBlacklistRoleTag(s logical.Storage, safety_buffer int) err
}
bufferDuration := time.Duration(safety_buffer) * time.Second
tags, err := s.List("blacklist/roletag/")
tags, err := s.List(ctx, "blacklist/roletag/")
if err != nil {
return err
}
for _, tag := range tags {
tagEntry, err := s.Get("blacklist/roletag/" + tag)
tagEntry, err := s.Get(ctx, "blacklist/roletag/"+tag)
if err != nil {
return fmt.Errorf("error fetching tag %s: %s", tag, err)
}
@ -66,7 +66,7 @@ func (b *backend) tidyBlacklistRoleTag(s logical.Storage, safety_buffer int) err
}
if time.Now().After(result.ExpirationTime.Add(bufferDuration)) {
if err := s.Delete("blacklist/roletag" + tag); err != nil {
if err := s.Delete(ctx, "blacklist/roletag"+tag); err != nil {
return fmt.Errorf("error deleting tag %s from storage: %s", tag, err)
}
}
@ -77,7 +77,7 @@ func (b *backend) tidyBlacklistRoleTag(s logical.Storage, safety_buffer int) err
// pathTidyRoletagBlacklistUpdate is used to clean-up the entries in the role tag blacklist.
func (b *backend) pathTidyRoletagBlacklistUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
return nil, b.tidyBlacklistRoleTag(req.Storage, data.Get("safety_buffer").(int))
return nil, b.tidyBlacklistRoleTag(ctx, req.Storage, data.Get("safety_buffer").(int))
}
const pathTidyRoletagBlacklistSyn = `

View File

@ -1,6 +1,7 @@
package cert
import (
"context"
"strings"
"sync"
@ -8,9 +9,9 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@ -50,7 +51,7 @@ type backend struct {
crlUpdateMutex *sync.RWMutex
}
func (b *backend) invalidate(key string) {
func (b *backend) invalidate(_ context.Context, key string) {
switch {
case strings.HasPrefix(key, "crls/"):
b.crlUpdateMutex.Lock()

View File

@ -306,7 +306,7 @@ func TestBackend_NonCAExpiry(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -366,7 +366,7 @@ func TestBackend_RegisteredNonCA_CRL(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -449,7 +449,7 @@ func TestBackend_CRLs(t *testing.T) {
storage := &logical.InmemStorage{}
config.StorageView = storage
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -586,7 +586,7 @@ func TestBackend_CRLs(t *testing.T) {
}
func testFactory(t *testing.T) logical.Backend {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: 1000 * time.Second,
MaxLeaseTTLVal: 1800 * time.Second,
@ -1135,7 +1135,7 @@ func testConnState(certPath, keyPath, rootCertPath string) (tls.ConnectionState,
func Test_Renew(t *testing.T) {
storage := &logical.InmemStorage{}
lb, err := Factory(&logical.BackendConfig{
lb, err := Factory(context.Background(), &logical.BackendConfig{
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: 300 * time.Second,
MaxLeaseTTLVal: 1800 * time.Second,
@ -1195,6 +1195,7 @@ func Test_Renew(t *testing.T) {
req.Auth.LeaseOptions = resp.Auth.LeaseOptions
req.Auth.Policies = resp.Auth.Policies
req.Auth.IssueTime = time.Now()
req.Auth.Period = resp.Auth.Period
// Normal renewal
resp, err = b.pathLoginRenew(context.Background(), req, empty_login_fd)
@ -1238,6 +1239,29 @@ func Test_Renew(t *testing.T) {
t.Fatalf("got error: %#v", *resp)
}
// Add period value to cert entry
period := 350 * time.Second
fd.Raw["period"] = period.String()
resp, err = b.pathCertWrite(context.Background(), req, fd)
if err != nil {
t.Fatal(err)
}
resp, err = b.pathLoginRenew(context.Background(), req, empty_login_fd)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("got nil response from renew")
}
if resp.IsError() {
t.Fatalf("got error: %#v", *resp)
}
if resp.Auth.Period != period {
t.Fatalf("expected a period value of %s in the response, got: %s", period, resp.Auth.Period)
}
// Delete CA, make sure we can't renew
resp, err = b.pathCertDelete(context.Background(), req, fd)
if err != nil {

View File

@ -40,17 +40,22 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
func (h *CLIHandler) Help() string {
help := `
The "cert" credential provider allows you to authenticate with a
client certificate. No other authentication materials are needed.
Optionally, you may specify the specific certificate role to
authenticate against with the "name" parameter.
Usage: vault login -method=cert [CONFIG K=V...]
Example: vault auth -method=cert \
-client-cert=/path/to/cert.pem \
-client-key=/path/to/key.pem
name=cert1
The certificate auth method allows uers to authenticate with a
client certificate passed with the request. The -client-cert and -client-key
flags are included with the "vault login" command, NOT as configuration to the
auth method.
`
Authenticate using a local client certificate:
$ vault login -method=cert -client-cert=cert.pem -client-key=key.pem
Configuration:
name=<string>
Certificate role to authenticate against.
`
return strings.TrimSpace(help)
}

View File

@ -101,8 +101,8 @@ TTL will be set to the value of this parameter.`,
}
}
func (b *backend) Cert(s logical.Storage, n string) (*CertEntry, error) {
entry, err := s.Get("cert/" + strings.ToLower(n))
func (b *backend) Cert(ctx context.Context, s logical.Storage, n string) (*CertEntry, error) {
entry, err := s.Get(ctx, "cert/"+strings.ToLower(n))
if err != nil {
return nil, err
}
@ -118,7 +118,7 @@ func (b *backend) Cert(s logical.Storage, n string) (*CertEntry, error) {
}
func (b *backend) pathCertDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete("cert/" + strings.ToLower(d.Get("name").(string)))
err := req.Storage.Delete(ctx, "cert/"+strings.ToLower(d.Get("name").(string)))
if err != nil {
return nil, err
}
@ -126,7 +126,7 @@ func (b *backend) pathCertDelete(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathCertList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
certs, err := req.Storage.List("cert/")
certs, err := req.Storage.List(ctx, "cert/")
if err != nil {
return nil, err
}
@ -134,7 +134,7 @@ func (b *backend) pathCertList(ctx context.Context, req *logical.Request, d *fra
}
func (b *backend) pathCertRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
cert, err := b.Cert(req.Storage, strings.ToLower(d.Get("name").(string)))
cert, err := b.Cert(ctx, req.Storage, strings.ToLower(d.Get("name").(string)))
if err != nil {
return nil, err
}
@ -144,12 +144,13 @@ func (b *backend) pathCertRead(ctx context.Context, req *logical.Request, d *fra
return &logical.Response{
Data: map[string]interface{}{
"certificate": cert.Certificate,
"display_name": cert.DisplayName,
"policies": cert.Policies,
"ttl": cert.TTL / time.Second,
"max_ttl": cert.MaxTTL / time.Second,
"period": cert.Period / time.Second,
"certificate": cert.Certificate,
"display_name": cert.DisplayName,
"policies": cert.Policies,
"ttl": cert.TTL / time.Second,
"max_ttl": cert.MaxTTL / time.Second,
"period": cert.Period / time.Second,
"allowed_names": cert.AllowedNames,
},
}, nil
}
@ -244,7 +245,7 @@ func (b *backend) pathCertWrite(ctx context.Context, req *logical.Request, d *fr
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@ -35,15 +35,15 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
return nil, nil
}
// Config returns the configuration for this backend.
func (b *backend) Config(s logical.Storage) (*config, error) {
entry, err := s.Get("config")
func (b *backend) Config(ctx context.Context, s logical.Storage) (*config, error) {
entry, err := s.Get(ctx, "config")
if err != nil {
return nil, err
}

View File

@ -42,7 +42,7 @@ using the same name as specified here.`,
}
}
func (b *backend) populateCRLs(storage logical.Storage) error {
func (b *backend) populateCRLs(ctx context.Context, storage logical.Storage) error {
b.crlUpdateMutex.Lock()
defer b.crlUpdateMutex.Unlock()
@ -52,7 +52,7 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
b.crls = map[string]CRLInfo{}
keys, err := storage.List("crls/")
keys, err := storage.List(ctx, "crls/")
if err != nil {
return fmt.Errorf("error listing CRLs: %v", err)
}
@ -61,7 +61,7 @@ func (b *backend) populateCRLs(storage logical.Storage) error {
}
for _, key := range keys {
entry, err := storage.Get("crls/" + key)
entry, err := storage.Get(ctx, "crls/"+key)
if err != nil {
b.crls = nil
return fmt.Errorf("error loading CRL %s: %v", key, err)
@ -129,7 +129,7 @@ func (b *backend) pathCRLDelete(ctx context.Context, req *logical.Request, d *fr
return logical.ErrorResponse(`"name" parameter cannot be empty`), nil
}
if err := b.populateCRLs(req.Storage); err != nil {
if err := b.populateCRLs(ctx, req.Storage); err != nil {
return nil, err
}
@ -143,7 +143,7 @@ func (b *backend) pathCRLDelete(ctx context.Context, req *logical.Request, d *fr
)), nil
}
if err := req.Storage.Delete("crls/" + name); err != nil {
if err := req.Storage.Delete(ctx, "crls/"+name); err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"error deleting crl %s: %v", name, err),
), nil
@ -160,7 +160,7 @@ func (b *backend) pathCRLRead(ctx context.Context, req *logical.Request, d *fram
return logical.ErrorResponse(`"name" parameter must be set`), nil
}
if err := b.populateCRLs(req.Storage); err != nil {
if err := b.populateCRLs(ctx, req.Storage); err != nil {
return nil, err
}
@ -198,7 +198,7 @@ func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *fra
return logical.ErrorResponse("parsed CRL is nil"), nil
}
if err := b.populateCRLs(req.Storage); err != nil {
if err := b.populateCRLs(ctx, req.Storage); err != nil {
return nil, err
}
@ -216,7 +216,7 @@ func (b *backend) pathCRLWrite(ctx context.Context, req *logical.Request, d *fra
if err != nil {
return nil, err
}
if err = req.Storage.Put(entry); err != nil {
if err = req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@ -60,7 +60,7 @@ func (b *backend) pathLoginAliasLookahead(ctx context.Context, req *logical.Requ
func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
var matched *ParsedCert
if verifyResp, resp, err := b.verifyCredentials(req, data); err != nil {
if verifyResp, resp, err := b.verifyCredentials(ctx, req, data); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
@ -128,14 +128,14 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
}
func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
config, err := b.Config(req.Storage)
config, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}
if !config.DisableBinding {
var matched *ParsedCert
if verifyResp, resp, err := b.verifyCredentials(req, d); err != nil {
if verifyResp, resp, err := b.verifyCredentials(ctx, req, d); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
@ -162,7 +162,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
}
// Get the cert and use its TTL
cert, err := b.Cert(req.Storage, req.Auth.Metadata["cert_name"])
cert, err := b.Cert(ctx, req.Storage, req.Auth.Metadata["cert_name"])
if err != nil {
return nil, err
}
@ -175,15 +175,20 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
return nil, fmt.Errorf("policies have changed, not renewing")
}
resp, err := framework.LeaseExtend(cert.TTL, cert.MaxTTL, b.System())(ctx, req, d)
if err != nil {
return nil, err
// If a period is provided, set that as part of resp.Auth.Period and return a
// response immediately. Let expiration manager handle renewal from there on.
if cert.Period > time.Duration(0) {
resp := &logical.Response{
Auth: req.Auth,
}
resp.Auth.Period = cert.Period
return resp, nil
}
resp.Auth.Period = cert.Period
return resp, nil
return framework.LeaseExtend(cert.TTL, cert.MaxTTL, b.System())(ctx, req, d)
}
func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData) (*ParsedCert, *logical.Response, error) {
func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, d *framework.FieldData) (*ParsedCert, *logical.Response, error) {
// Get the connection state
if req.Connection == nil || req.Connection.ConnState == nil {
return nil, logical.ErrorResponse("tls connection required"), nil
@ -204,9 +209,10 @@ func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData
}
// Load the trusted certificates
roots, trusted, trustedNonCAs := b.loadTrustedCerts(req.Storage, certName)
roots, trusted, trustedNonCAs := b.loadTrustedCerts(ctx, req.Storage, certName)
// Get the list of full chains matching the connection
// Get the list of full chains matching the connection and validates the
// certificate itself
trustedChains, err := validateConnState(roots, connState)
if err != nil {
return nil, nil, err
@ -227,6 +233,7 @@ func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData
}
// If no trusted chain was found, client is not authenticated
// This check happens after checking for a matching configured non-CA certs
if len(trustedChains) == 0 {
return nil, logical.ErrorResponse("invalid certificate or no client certificate supplied"), nil
}
@ -324,11 +331,11 @@ func (b *backend) matchesCertificateExtenions(clientCert *x509.Certificate, conf
}
// loadTrustedCerts is used to load all the trusted certificates from the backend
func (b *backend) loadTrustedCerts(store logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert) {
func (b *backend) loadTrustedCerts(ctx context.Context, storage logical.Storage, certName string) (pool *x509.CertPool, trusted []*ParsedCert, trustedNonCAs []*ParsedCert) {
pool = x509.NewCertPool()
trusted = make([]*ParsedCert, 0)
trustedNonCAs = make([]*ParsedCert, 0)
names, err := store.List("cert/")
names, err := storage.List(ctx, "cert/")
if err != nil {
b.Logger().Error("cert: failed to list trusted certs", "error", err)
return
@ -338,7 +345,7 @@ func (b *backend) loadTrustedCerts(store logical.Storage, certName string) (pool
if certName != "" && name != certName {
continue
}
entry, err := b.Cert(store, strings.TrimPrefix(name, "cert/"))
entry, err := b.Cert(ctx, storage, strings.TrimPrefix(name, "cert/"))
if err != nil {
b.Logger().Error("cert: failed to load trusted cert", "name", name, "error", err)
continue
@ -415,17 +422,17 @@ func parsePEM(raw []byte) (certs []*x509.Certificate) {
// verification logic here: http://golang.org/src/crypto/tls/handshake_server.go
// The trusted chains are returned.
func validateConnState(roots *x509.CertPool, cs *tls.ConnectionState) ([][]*x509.Certificate, error) {
certs := cs.PeerCertificates
if len(certs) == 0 {
return nil, nil
}
opts := x509.VerifyOptions{
Roots: roots,
Intermediates: x509.NewCertPool(),
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
certs := cs.PeerCertificates
if len(certs) == 0 {
return nil, nil
}
if len(certs) > 1 {
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)

View File

@ -11,9 +11,9 @@ import (
"golang.org/x/oauth2"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil

View File

@ -1,6 +1,7 @@
package github
import (
"context"
"fmt"
"os"
"strings"
@ -14,7 +15,7 @@ import (
func TestBackend_Config(t *testing.T) {
defaultLeaseTTLVal := time.Hour * 24
maxLeaseTTLVal := time.Hour * 24 * 2
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,
@ -92,7 +93,7 @@ func testConfigWrite(t *testing.T, d map[string]interface{}) logicaltest.TestSte
func TestBackend_basic(t *testing.T) {
defaultLeaseTTLVal := time.Hour * 24
maxLeaseTTLVal := time.Hour * 24 * 32
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,

View File

@ -2,13 +2,18 @@ package github
import (
"fmt"
"io"
"os"
"strings"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/password"
)
type CLIHandler struct{}
type CLIHandler struct {
// for tests
testStdout io.Writer
}
func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, error) {
mount, ok := m["mount"]
@ -16,16 +21,39 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
mount = "github"
}
token, ok := m["token"]
if !ok {
if token = os.Getenv("VAULT_AUTH_GITHUB_TOKEN"); token == "" {
return nil, fmt.Errorf("GitHub token should be provided either as 'value' for 'token' key,\nor via an env var VAULT_AUTH_GITHUB_TOKEN")
// Extract or prompt for token
token := m["token"]
if token == "" {
token = os.Getenv("VAULT_AUTH_GITHUB_TOKEN")
}
if token == "" {
// Override the output
stdout := h.testStdout
if stdout == nil {
stdout = os.Stderr
}
var err error
fmt.Fprintf(stdout, "GitHub Personal Access Token (will be hidden): ")
token, err = password.Read(os.Stdin)
fmt.Fprintf(stdout, "\n")
if err != nil {
if err == password.ErrInterrupted {
return nil, fmt.Errorf("user interrupted")
}
return nil, fmt.Errorf("An error occurred attempting to "+
"ask for a token. The raw error message is shown below, but usually "+
"this is because you attempted to pipe a value into the command or "+
"you are executing outside of a terminal (tty). If you want to pipe "+
"the value, pass \"-\" as the argument to read from stdin. The raw "+
"error was: %s", err)
}
}
path := fmt.Sprintf("auth/%s/login", mount)
secret, err := c.Logical().Write(path, map[string]interface{}{
"token": token,
"token": strings.TrimSpace(token),
})
if err != nil {
return nil, err
@ -39,20 +67,28 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
func (h *CLIHandler) Help() string {
help := `
The GitHub credential provider allows you to authenticate with GitHub.
To use it, specify the "token" parameter. The value should be a personal access
token for your GitHub account. You can generate a personal access token on your
account settings page on GitHub.
Usage: vault login -method=github [CONFIG K=V...]
Example: vault auth -method=github token=<token>
The GitHub auth method allows users to authenticate using a GitHub
personal access token. Users can generate a personal access token from the
settings page on their GitHub account.
Key/Value Pairs:
Authenticate using a GitHub token:
mount=github The mountpoint for the GitHub credential provider.
Defaults to "github"
$ vault login -method=github token=abcd1234
token=<token> The GitHub personal access token for authentication.
`
Configuration:
mount=<string>
Path where the GitHub credential method is mounted. This is usually
provided via the -path flag in the "vault login" command, but it can be
specified here as well. If specified here, it takes precedence over the
value for -path. The default value is "github".
token=<string>
GitHub personal access token to use for authentication. If not provided,
Vault will prompt for the value.
`
return strings.TrimSpace(help)
}

View File

@ -6,7 +6,6 @@ import (
"net/url"
"time"
"github.com/fatih/structs"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -87,7 +86,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@ -95,7 +94,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, dat
}
func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
config, err := b.Config(req.Storage)
config, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}
@ -108,14 +107,19 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, data
config.MaxTTL /= time.Second
resp := &logical.Response{
Data: structs.New(config).Map(),
Data: map[string]interface{}{
"organization": config.Organization,
"base_url": config.BaseURL,
"ttl": config.TTL,
"max_ttl": config.MaxTTL,
},
}
return resp, nil
}
// Config returns the configuration for this backend.
func (b *backend) Config(s logical.Storage) (*config, error) {
entry, err := s.Get("config")
func (b *backend) Config(ctx context.Context, s logical.Storage) (*config, error) {
entry, err := s.Get(ctx, "config")
if err != nil {
return nil, err
}

View File

@ -33,7 +33,7 @@ func (b *backend) pathLoginAliasLookahead(ctx context.Context, req *logical.Requ
token := data.Get("token").(string)
var verifyResp *verifyCredentialsResp
if verifyResponse, resp, err := b.verifyCredentials(req, token); err != nil {
if verifyResponse, resp, err := b.verifyCredentials(ctx, req, token); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
@ -54,7 +54,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
token := data.Get("token").(string)
var verifyResp *verifyCredentialsResp
if verifyResponse, resp, err := b.verifyCredentials(req, token); err != nil {
if verifyResponse, resp, err := b.verifyCredentials(ctx, req, token); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
@ -62,7 +62,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, data *fra
verifyResp = verifyResponse
}
config, err := b.Config(req.Storage)
config, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}
@ -117,7 +117,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
token := tokenRaw.(string)
var verifyResp *verifyCredentialsResp
if verifyResponse, resp, err := b.verifyCredentials(req, token); err != nil {
if verifyResponse, resp, err := b.verifyCredentials(ctx, req, token); err != nil {
return nil, err
} else if resp != nil {
return resp, nil
@ -128,7 +128,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
return nil, fmt.Errorf("policies do not match")
}
config, err := b.Config(req.Storage)
config, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}
@ -150,8 +150,8 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
return resp, nil
}
func (b *backend) verifyCredentials(req *logical.Request, token string) (*verifyCredentialsResp, *logical.Response, error) {
config, err := b.Config(req.Storage)
func (b *backend) verifyCredentials(ctx context.Context, req *logical.Request, token string) (*verifyCredentialsResp, *logical.Response, error) {
config, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, nil, err
}
@ -174,7 +174,7 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
}
// Get the user
user, _, err := client.Users.Get(context.Background(), "")
user, _, err := client.Users.Get(ctx, "")
if err != nil {
return nil, nil, err
}
@ -188,7 +188,7 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
var allOrgs []*github.Organization
for {
orgs, resp, err := client.Organizations.List(context.Background(), "", orgOpt)
orgs, resp, err := client.Organizations.List(ctx, "", orgOpt)
if err != nil {
return nil, nil, err
}
@ -218,7 +218,7 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
var allTeams []*github.Team
for {
teams, resp, err := client.Organizations.ListUserTeams(context.Background(), teamOpt)
teams, resp, err := client.Organizations.ListUserTeams(ctx, teamOpt)
if err != nil {
return nil, nil, err
}
@ -242,13 +242,13 @@ func (b *backend) verifyCredentials(req *logical.Request, token string) (*verify
}
}
groupPoliciesList, err := b.TeamMap.Policies(req.Storage, teamNames...)
groupPoliciesList, err := b.TeamMap.Policies(ctx, req.Storage, teamNames...)
if err != nil {
return nil, nil, err
}
userPoliciesList, err := b.UserMap.Policies(req.Storage, []string{*user.Login}...)
userPoliciesList, err := b.UserMap.Policies(ctx, req.Storage, []string{*user.Login}...)
if err != nil {
return nil, nil, err

View File

@ -2,6 +2,7 @@ package ldap
import (
"bytes"
"context"
"fmt"
"text/template"
@ -12,9 +13,9 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@ -92,9 +93,9 @@ func EscapeLDAPValue(input string) string {
return input
}
func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
func (b *backend) Login(ctx context.Context, req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
cfg, err := b.Config(req)
cfg, err := b.Config(ctx, req)
if err != nil {
return nil, nil, nil, err
}
@ -172,7 +173,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
var allGroups []string
// Import the custom added groups from ldap backend
user, err := b.User(req.Storage, username)
user, err := b.User(ctx, req.Storage, username)
if err == nil && user != nil && user.Groups != nil {
if b.Logger().IsDebug() {
b.Logger().Debug("auth/ldap: adding local groups", "num_local_groups", len(user.Groups), "local_groups", user.Groups)
@ -185,7 +186,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
// Retrieve policies
var policies []string
for _, groupName := range allGroups {
group, err := b.Group(req.Storage, groupName)
group, err := b.Group(ctx, req.Storage, groupName)
if err == nil && group != nil {
policies = append(policies, group.Policies...)
}

View File

@ -23,7 +23,7 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) {
t.Fatalf("failed to create backend")
}
err := b.Backend.Setup(config)
err := b.Backend.Setup(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -103,7 +103,7 @@ func TestLdapAuthBackend_UserPolicies(t *testing.T) {
}
/*
* Acceptance test for LDAP Auth Backend
* Acceptance test for LDAP Auth Method
*
* The tests here rely on a public LDAP server:
* [http://www.forumsys.com/tutorials/integration-how-to/ldap/online-ldap-test-server/]
@ -120,7 +120,7 @@ func TestLdapAuthBackend_UserPolicies(t *testing.T) {
func factory(t *testing.T) logical.Backend {
defaultLeaseTTLVal := time.Hour * 24
maxLeaseTTLVal := time.Hour * 24 * 32
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,

View File

@ -26,10 +26,10 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
}
password, ok := m["password"]
if !ok {
fmt.Printf("Password (will be hidden): ")
fmt.Fprintf(os.Stderr, "Password (will be hidden): ")
var err error
password, err = pwd.Read(os.Stdin)
fmt.Println()
fmt.Fprintf(os.Stderr, "\n")
if err != nil {
return nil, err
}
@ -62,18 +62,40 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
func (h *CLIHandler) Help() string {
help := `
The LDAP credential provider allows you to authenticate with LDAP.
To use it, first configure it through the "config" endpoint, and then
login by specifying username and password. If password is not provided
on the command line, it will be read from stdin.
Usage: vault login -method=ldap [CONFIG K=V...]
If multi-factor authentication (MFA) is enabled, a "method" and/or "passcode"
may be provided depending on the MFA backend enabled. To check
which MFA backend is in use, read "auth/[mount]/mfa_config".
The LDAP auth method allows users to authenticate using LDAP or
Active Directory.
Example: vault auth -method=ldap username=john
If MFA is enabled, a "method" and/or "passcode" may be required depending on
the MFA method. To check which MFA is in use, run:
`
$ vault read auth/<mount>/mfa_config
Authenticate as "sally":
$ vault login -method=ldap username=sally
Password (will be hidden):
Authenticate as "bob":
$ vault login -method=ldap username=bob password=password
Configuration:
method=<string>
MFA method.
passcode=<string>
MFA OTP/passcode.
password=<string>
LDAP password to use for authentication. If not provided, the CLI will
prompt for this on stdin.
username=<string>
LDAP username to use for authentication.
`
return strings.TrimSpace(help)
}

View File

@ -130,7 +130,7 @@ Default: cn`,
/*
* Construct ConfigEntry struct using stored configuration.
*/
func (b *backend) Config(req *logical.Request) (*ConfigEntry, error) {
func (b *backend) Config(ctx context.Context, req *logical.Request) (*ConfigEntry, error) {
// Schema for ConfigEntry
fd, err := b.getConfigFieldData()
if err != nil {
@ -143,7 +143,7 @@ func (b *backend) Config(req *logical.Request) (*ConfigEntry, error) {
return nil, err
}
storedConfig, err := req.Storage.Get("config")
storedConfig, err := req.Storage.Get(ctx, "config")
if err != nil {
return nil, err
}
@ -165,7 +165,7 @@ func (b *backend) Config(req *logical.Request) (*ConfigEntry, error) {
}
func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
cfg, err := b.Config(req)
cfg, err := b.Config(ctx, req)
if err != nil {
return nil, err
}
@ -299,7 +299,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@ -47,8 +47,8 @@ func pathGroups(b *backend) *framework.Path {
}
}
func (b *backend) Group(s logical.Storage, n string) (*GroupEntry, error) {
entry, err := s.Get("group/" + n)
func (b *backend) Group(ctx context.Context, s logical.Storage, n string) (*GroupEntry, error) {
entry, err := s.Get(ctx, "group/"+n)
if err != nil {
return nil, err
}
@ -65,7 +65,7 @@ func (b *backend) Group(s logical.Storage, n string) (*GroupEntry, error) {
}
func (b *backend) pathGroupDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete("group/" + d.Get("name").(string))
err := req.Storage.Delete(ctx, "group/"+d.Get("name").(string))
if err != nil {
return nil, err
}
@ -74,7 +74,7 @@ func (b *backend) pathGroupDelete(ctx context.Context, req *logical.Request, d *
}
func (b *backend) pathGroupRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
group, err := b.Group(req.Storage, d.Get("name").(string))
group, err := b.Group(ctx, req.Storage, d.Get("name").(string))
if err != nil {
return nil, err
}
@ -97,7 +97,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@ -105,7 +105,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathGroupList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
groups, err := req.Storage.List("group/")
groups, err := req.Storage.List(ctx, "group/")
if err != nil {
return nil, err
}

View File

@ -54,7 +54,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
username := d.Get("username").(string)
password := d.Get("password").(string)
policies, resp, groupNames, err := b.Login(req, username, password)
policies, resp, groupNames, err := b.Login(ctx, req, username, password)
// Handle an internal error
if err != nil {
return nil, err
@ -102,7 +102,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
username := req.Auth.Metadata["username"]
password := req.Auth.InternalData["password"].(string)
loginPolicies, resp, groupNames, err := b.Login(req, username, password)
loginPolicies, resp, groupNames, err := b.Login(ctx, req, username, password)
if len(loginPolicies) == 0 {
return resp, err
}

View File

@ -54,8 +54,8 @@ func pathUsers(b *backend) *framework.Path {
}
}
func (b *backend) User(s logical.Storage, n string) (*UserEntry, error) {
entry, err := s.Get("user/" + n)
func (b *backend) User(ctx context.Context, s logical.Storage, n string) (*UserEntry, error) {
entry, err := s.Get(ctx, "user/"+n)
if err != nil {
return nil, err
}
@ -72,7 +72,7 @@ func (b *backend) User(s logical.Storage, n string) (*UserEntry, error) {
}
func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete("user/" + d.Get("name").(string))
err := req.Storage.Delete(ctx, "user/"+d.Get("name").(string))
if err != nil {
return nil, err
}
@ -81,7 +81,7 @@ func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
user, err := b.User(req.Storage, d.Get("name").(string))
user, err := b.User(ctx, req.Storage, d.Get("name").(string))
if err != nil {
return nil, err
}
@ -113,7 +113,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@ -121,7 +121,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
}
func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
users, err := req.Storage.List("user/")
users, err := req.Storage.List(ctx, "user/")
if err != nil {
return nil, err
}

View File

@ -1,6 +1,7 @@
package okta
import (
"context"
"fmt"
"github.com/chrismalek/oktasdk-go/okta"
@ -9,9 +10,9 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@ -54,13 +55,13 @@ type backend struct {
*framework.Backend
}
func (b *backend) Login(req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
cfg, err := b.Config(req.Storage)
func (b *backend) Login(ctx context.Context, req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
cfg, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, nil, nil, err
}
if cfg == nil {
return nil, logical.ErrorResponse("Okta backend not configured"), nil, nil
return nil, logical.ErrorResponse("Okta auth method not configured"), nil, nil
}
client := cfg.OktaClient()
@ -71,6 +72,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
type authResult struct {
Embedded embeddedResult `json:"_embedded"`
Status string `json:"status"`
}
authReq, err := client.NewRequest("POST", "authn", map[string]interface{}{
@ -87,13 +89,50 @@ func (b *backend) Login(req *logical.Request, username string, password string)
return nil, logical.ErrorResponse(fmt.Sprintf("Okta auth failed: %v", err)), nil, nil
}
if rsp == nil {
return nil, logical.ErrorResponse("okta auth backend unexpected failure"), nil, nil
return nil, logical.ErrorResponse("okta auth method unexpected failure"), nil, nil
}
oktaResponse := &logical.Response{
Data: map[string]interface{}{},
}
// If lockout failures are not configured to be hidden, the status needs to
// be inspected for LOCKED_OUT status. Otherwise, it is handled above by an
// error returned during the authentication request.
switch result.Status {
case "LOCKED_OUT":
if b.Logger().IsDebug() {
b.Logger().Debug("auth/okta: user is locked out", "user", username)
}
return nil, logical.ErrorResponse("okta authentication failed"), nil, nil
case "PASSWORD_EXPIRED":
if b.Logger().IsDebug() {
b.Logger().Debug("auth/okta: password is expired", "user", username)
}
return nil, logical.ErrorResponse("okta authentication failed"), nil, nil
case "PASSWORD_WARN":
oktaResponse.AddWarning("Your Okta password is in warning state and needs to be changed soon.")
case "SUCCESS":
// Do nothing here
default:
if b.Logger().IsDebug() {
b.Logger().Debug("auth/okta: unhandled result status", "status", result.Status)
}
return nil, logical.ErrorResponse("okta authentication failed"), nil, nil
}
// Verify result status again in case a switch case above modifies result
if result.Status != "SUCCESS" && result.Status != "PASSWORD_WARN" {
if b.Logger().IsDebug() {
b.Logger().Debug("auth/okta: authentication returned a non-success status", "status", result.Status)
}
return nil, logical.ErrorResponse("okta authentication failed"), nil, nil
}
var allGroups []string
// Only query the Okta API for group membership if we have a token
if cfg.Token != "" {
@ -110,7 +149,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
}
// Import the custom added groups from okta backend
user, err := b.User(req.Storage, username)
user, err := b.User(ctx, req.Storage, username)
if err != nil {
if b.Logger().IsDebug() {
b.Logger().Debug("auth/okta: error looking up user", "error", err)
@ -126,7 +165,7 @@ func (b *backend) Login(req *logical.Request, username string, password string)
// Retrieve policies
var policies []string
for _, groupName := range allGroups {
entry, _, err := b.Group(req.Storage, groupName)
entry, _, err := b.Group(ctx, req.Storage, groupName)
if err != nil {
if b.Logger().IsDebug() {
b.Logger().Debug("auth/okta: error looking up group policies", "error", err)
@ -161,7 +200,7 @@ func (b *backend) getOktaGroups(client *okta.Client, user *okta.User) ([]string,
return nil, err
}
if rsp == nil {
return nil, fmt.Errorf("okta auth backend unexpected failure")
return nil, fmt.Errorf("okta auth method unexpected failure")
}
oktaGroups := make([]string, 0, len(user.Groups))
for _, group := range user.Groups {

View File

@ -1,6 +1,7 @@
package okta
import (
"context"
"fmt"
"os"
"strings"
@ -19,7 +20,7 @@ import (
func TestBackend_Config(t *testing.T) {
defaultLeaseTTLVal := time.Hour * 12
maxLeaseTTLVal := time.Hour * 24
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: logformat.NewVaultLogger(log.LevelTrace),
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,
@ -53,16 +54,16 @@ func TestBackend_Config(t *testing.T) {
testConfigCreate(t, configData),
testLoginWrite(t, username, "wrong", "E0000004", 0, nil),
testLoginWrite(t, username, password, "user is not a member of any authorized policy", 0, nil),
testAccUserGroups(t, username, "local_grouP,lOcal_group2"),
testAccUserGroups(t, username, "local_grouP,lOcal_group2", []string{"user_policy"}),
testAccGroups(t, "local_groUp", "loCal_group_policy"),
testLoginWrite(t, username, password, "", defaultLeaseTTLVal, []string{"local_group_policy"}),
testLoginWrite(t, username, password, "", defaultLeaseTTLVal, []string{"local_group_policy", "user_policy"}),
testAccGroups(t, "everyoNe", "everyone_grouP_policy,eveRy_group_policy2"),
testLoginWrite(t, username, password, "", defaultLeaseTTLVal, []string{"local_group_policy"}),
testLoginWrite(t, username, password, "", defaultLeaseTTLVal, []string{"local_group_policy", "user_policy"}),
testConfigUpdate(t, configDataToken),
testConfigRead(t, token, configData),
testLoginWrite(t, username, password, "", updatedDuration, []string{"everyone_group_policy", "every_group_policy2", "local_group_policy"}),
testLoginWrite(t, username, password, "", updatedDuration, []string{"everyone_group_policy", "every_group_policy2", "local_group_policy", "user_policy"}),
testAccGroups(t, "locAl_group2", "testgroup_group_policy"),
testLoginWrite(t, username, password, "", updatedDuration, []string{"everyone_group_policy", "every_group_policy2", "local_group_policy", "testgroup_group_policy"}),
testLoginWrite(t, username, password, "", updatedDuration, []string{"everyone_group_policy", "every_group_policy2", "local_group_policy", "testgroup_group_policy", "user_policy"}),
},
})
}
@ -154,19 +155,24 @@ func testAccPreCheck(t *testing.T) {
if v := os.Getenv("OKTA_ORG"); v == "" {
t.Fatal("OKTA_ORG must be set for acceptance tests")
}
if v := os.Getenv("OKTA_API_TOKEN"); v == "" {
t.Fatal("OKTA_API_TOKEN must be set for acceptance tests")
}
}
func testAccUserGroups(t *testing.T, user string, groups string) logicaltest.TestStep {
func testAccUserGroups(t *testing.T, user string, groups interface{}, policies interface{}) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.UpdateOperation,
Path: "users/" + user,
Data: map[string]interface{}{
"groups": groups,
"groups": groups,
"policies": policies,
},
}
}
func testAccGroups(t *testing.T, group string, policies string) logicaltest.TestStep {
func testAccGroups(t *testing.T, group string, policies interface{}) logicaltest.TestStep {
t.Logf("[testAccGroups] - Registering group %s, policy %s", group, policies)
return logicaltest.TestStep{
Operation: logical.UpdateOperation,

View File

@ -25,10 +25,10 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
}
password, ok := m["password"]
if !ok {
fmt.Printf("Password (will be hidden): ")
fmt.Fprintf(os.Stderr, "Password (will be hidden): ")
var err error
password, err = pwd.Read(os.Stdin)
fmt.Println()
fmt.Fprintf(os.Stderr, "\n")
if err != nil {
return nil, err
}
@ -62,14 +62,28 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
// Help method for okta cli
func (h *CLIHandler) Help() string {
help := `
The Okta credential provider allows you to authenticate with Okta.
To use it, first configure it through the "config" endpoint, and then
login by specifying username and password. If password is not provided
on the command line, it will be read from stdin.
Usage: vault login -method=okta [CONFIG K=V...]
Example: vault auth -method=okta username=john
The Okta auth method allows users to authenticate using Okta.
`
Authenticate as "sally":
$ vault login -method=okta username=sally
Password (will be hidden):
Authenticate as "bob":
$ vault login -method=okta username=bob password=password
Configuration:
password=<string>
Okta password to use for authentication. If not provided, the CLI will
prompt for this on stdin.
username=<string>
Okta username to use for authentication.
`
return strings.TrimSpace(help)
}

View File

@ -69,8 +69,8 @@ func pathConfig(b *backend) *framework.Path {
}
// Config returns the configuration for this backend.
func (b *backend) Config(s logical.Storage) (*ConfigEntry, error) {
entry, err := s.Get("config")
func (b *backend) Config(ctx context.Context, s logical.Storage) (*ConfigEntry, error) {
entry, err := s.Get(ctx, "config")
if err != nil {
return nil, err
}
@ -89,7 +89,7 @@ func (b *backend) Config(s logical.Storage) (*ConfigEntry, error) {
}
func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
cfg, err := b.Config(req.Storage)
cfg, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}
@ -116,7 +116,7 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
cfg, err := b.Config(req.Storage)
cfg, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}
@ -193,7 +193,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
if err != nil {
return nil, err
}
if err := req.Storage.Put(jsonCfg); err != nil {
if err := req.Storage.Put(ctx, jsonCfg); err != nil {
return nil, err
}
@ -201,7 +201,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
}
func (b *backend) pathConfigExistenceCheck(ctx context.Context, req *logical.Request, d *framework.FieldData) (bool, error) {
cfg, err := b.Config(req.Storage)
cfg, err := b.Config(ctx, req.Storage)
if err != nil {
return false, err
}

View File

@ -50,20 +50,20 @@ func pathGroups(b *backend) *framework.Path {
// We look up groups in a case-insensitive manner since Okta is case-preserving
// but case-insensitive for comparisons
func (b *backend) Group(s logical.Storage, n string) (*GroupEntry, string, error) {
func (b *backend) Group(ctx context.Context, s logical.Storage, n string) (*GroupEntry, string, error) {
canonicalName := n
entry, err := s.Get("group/" + n)
entry, err := s.Get(ctx, "group/"+n)
if err != nil {
return nil, "", err
}
if entry == nil {
entries, err := s.List("group/")
entries, err := s.List(ctx, "group/")
if err != nil {
return nil, "", err
}
for _, groupName := range entries {
if strings.ToLower(groupName) == strings.ToLower(n) {
entry, err = s.Get("group/" + groupName)
entry, err = s.Get(ctx, "group/"+groupName)
if err != nil {
return nil, "", err
}
@ -90,12 +90,12 @@ func (b *backend) pathGroupDelete(ctx context.Context, req *logical.Request, d *
return logical.ErrorResponse("'name' must be supplied"), nil
}
entry, canonicalName, err := b.Group(req.Storage, name)
entry, canonicalName, err := b.Group(ctx, req.Storage, name)
if err != nil {
return nil, err
}
if entry != nil {
err := req.Storage.Delete("group/" + canonicalName)
err := req.Storage.Delete(ctx, "group/"+canonicalName)
if err != nil {
return nil, err
}
@ -110,7 +110,7 @@ func (b *backend) pathGroupRead(ctx context.Context, req *logical.Request, d *fr
return logical.ErrorResponse("'name' must be supplied"), nil
}
group, _, err := b.Group(req.Storage, name)
group, _, err := b.Group(ctx, req.Storage, name)
if err != nil {
return nil, err
}
@ -133,7 +133,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
// Check for an existing group, possibly lowercased so that we keep using
// existing user set values
_, canonicalName, err := b.Group(req.Storage, name)
_, canonicalName, err := b.Group(ctx, req.Storage, name)
if err != nil {
return nil, err
}
@ -149,7 +149,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@ -157,7 +157,7 @@ func (b *backend) pathGroupWrite(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathGroupList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
groups, err := req.Storage.List("group/")
groups, err := req.Storage.List(ctx, "group/")
if err != nil {
return nil, err
}

View File

@ -56,7 +56,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
username := d.Get("username").(string)
password := d.Get("password").(string)
policies, resp, groupNames, err := b.Login(req, username, password)
policies, resp, groupNames, err := b.Login(ctx, req, username, password)
// Handle an internal error
if err != nil {
return nil, err
@ -72,7 +72,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
sort.Strings(policies)
cfg, err := b.getConfig(req)
cfg, err := b.getConfig(ctx, req)
if err != nil {
return nil, err
}
@ -112,7 +112,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
username := req.Auth.Metadata["username"]
password := req.Auth.InternalData["password"].(string)
loginPolicies, resp, groupNames, err := b.Login(req, username, password)
loginPolicies, resp, groupNames, err := b.Login(ctx, req, username, password)
if len(loginPolicies) == 0 {
return resp, err
}
@ -121,7 +121,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
return nil, fmt.Errorf("policies have changed, not renewing")
}
cfg, err := b.getConfig(req)
cfg, err := b.getConfig(ctx, req)
if err != nil {
return nil, err
}
@ -144,9 +144,9 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) getConfig(req *logical.Request) (*ConfigEntry, error) {
func (b *backend) getConfig(ctx context.Context, req *logical.Request) (*ConfigEntry, error) {
cfg, err := b.Config(req.Storage)
cfg, err := b.Config(ctx, req.Storage)
if err != nil {
return nil, err
}

View File

@ -2,7 +2,6 @@ package okta
import (
"context"
"strings"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
@ -31,13 +30,13 @@ func pathUsers(b *backend) *framework.Path {
},
"groups": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Comma-separated list of groups associated with the user.",
Type: framework.TypeCommaStringSlice,
Description: "List of groups associated with the user.",
},
"policies": &framework.FieldSchema{
Type: framework.TypeString,
Description: "Comma-separated list of policies associated with the user.",
Type: framework.TypeCommaStringSlice,
Description: "List of policies associated with the user.",
},
},
@ -52,8 +51,8 @@ func pathUsers(b *backend) *framework.Path {
}
}
func (b *backend) User(s logical.Storage, n string) (*UserEntry, error) {
entry, err := s.Get("user/" + n)
func (b *backend) User(ctx context.Context, s logical.Storage, n string) (*UserEntry, error) {
entry, err := s.Get(ctx, "user/"+n)
if err != nil {
return nil, err
}
@ -75,7 +74,7 @@ func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *f
return logical.ErrorResponse("Error empty name"), nil
}
err := req.Storage.Delete("user/" + name)
err := req.Storage.Delete(ctx, "user/"+name)
if err != nil {
return nil, err
}
@ -89,7 +88,7 @@ func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *fra
return logical.ErrorResponse("Error empty name"), nil
}
user, err := b.User(req.Storage, name)
user, err := b.User(ctx, req.Storage, name)
if err != nil {
return nil, err
}
@ -111,15 +110,8 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
return logical.ErrorResponse("Error empty name"), nil
}
groups := strings.Split(d.Get("groups").(string), ",")
for i, g := range groups {
groups[i] = strings.TrimSpace(g)
}
policies := strings.Split(d.Get("policies").(string), ",")
for i, p := range policies {
policies[i] = strings.TrimSpace(p)
}
groups := d.Get("groups").([]string)
policies := d.Get("policies").([]string)
// Store it
entry, err := logical.StorageEntryJSON("user/"+name, &UserEntry{
@ -129,7 +121,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@ -137,7 +129,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
}
func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
users, err := req.Storage.List("user/")
users, err := req.Storage.List(ctx, "user/")
if err != nil {
return nil, err
}

View File

@ -1,14 +1,16 @@
package radius
import (
"context"
"github.com/hashicorp/vault/helper/mfa"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil

View File

@ -1,6 +1,7 @@
package radius
import (
"context"
"fmt"
"os"
"reflect"
@ -17,7 +18,7 @@ const (
)
func TestBackend_Config(t *testing.T) {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,
@ -70,7 +71,7 @@ func TestBackend_Config(t *testing.T) {
}
func TestBackend_users(t *testing.T) {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,
@ -98,7 +99,7 @@ func TestBackend_acceptance(t *testing.T) {
return
}
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,

View File

@ -65,7 +65,7 @@ func pathConfig(b *backend) *framework.Path {
// Establishes dichotomy of request operation between CreateOperation and UpdateOperation.
// Returning 'true' forces an UpdateOperation, CreateOperation otherwise.
func (b *backend) configExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
entry, err := b.Config(req)
entry, err := b.Config(ctx, req)
if err != nil {
return false, err
}
@ -75,9 +75,8 @@ func (b *backend) configExistenceCheck(ctx context.Context, req *logical.Request
/*
* Construct ConfigEntry struct using stored configuration.
*/
func (b *backend) Config(req *logical.Request) (*ConfigEntry, error) {
storedConfig, err := req.Storage.Get("config")
func (b *backend) Config(ctx context.Context, req *logical.Request) (*ConfigEntry, error) {
storedConfig, err := req.Storage.Get(ctx, "config")
if err != nil {
return nil, err
}
@ -96,7 +95,7 @@ func (b *backend) Config(req *logical.Request) (*ConfigEntry, error) {
}
func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
cfg, err := b.Config(req)
cfg, err := b.Config(ctx, req)
if err != nil {
return nil, err
}
@ -113,7 +112,7 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *f
func (b *backend) pathConfigCreateUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Build a ConfigEntry struct out of the supplied FieldData
cfg, err := b.Config(req)
cfg, err := b.Config(ctx, req)
if err != nil {
return nil, err
}
@ -156,7 +155,7 @@ func (b *backend) pathConfigCreateUpdate(ctx context.Context, req *logical.Reque
policies = strings.Split(unregisteredUserPoliciesStr, ",")
for _, policy := range policies {
if policy == "root" {
return logical.ErrorResponse("root policy cannot be granted by an authentication backend"), nil
return logical.ErrorResponse("root policy cannot be granted by an auth method"), nil
}
}
}
@ -190,7 +189,7 @@ func (b *backend) pathConfigCreateUpdate(ctx context.Context, req *logical.Reque
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@ -76,7 +76,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
return logical.ErrorResponse("password cannot be empty"), nil
}
policies, resp, err := b.RadiusLogin(req, username, password)
policies, resp, err := b.RadiusLogin(ctx, req, username, password)
// Handle an internal error
if err != nil {
return nil, err
@ -117,7 +117,7 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
var resp *logical.Response
var loginPolicies []string
loginPolicies, resp, err = b.RadiusLogin(req, username, password)
loginPolicies, resp, err = b.RadiusLogin(ctx, req, username, password)
if err != nil || (resp != nil && resp.IsError()) {
return resp, err
}
@ -129,9 +129,9 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f
return framework.LeaseExtend(0, 0, b.System())(ctx, req, d)
}
func (b *backend) RadiusLogin(req *logical.Request, username string, password string) ([]string, *logical.Response, error) {
func (b *backend) RadiusLogin(ctx context.Context, req *logical.Request, username string, password string) ([]string, *logical.Response, error) {
cfg, err := b.Config(req)
cfg, err := b.Config(ctx, req)
if err != nil {
return nil, nil, err
}
@ -163,7 +163,7 @@ func (b *backend) RadiusLogin(req *logical.Request, username string, password st
var policies []string
// Retrieve user entry from storage
user, err := b.user(req.Storage, username)
user, err := b.user(ctx, req.Storage, username)
if err != nil {
return policies, logical.ErrorResponse("could not retrieve user entry from storage"), err
}

View File

@ -53,7 +53,7 @@ func pathUsers(b *backend) *framework.Path {
}
func (b *backend) userExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
userEntry, err := b.user(req.Storage, data.Get("name").(string))
userEntry, err := b.user(ctx, req.Storage, data.Get("name").(string))
if err != nil {
return false, err
}
@ -61,12 +61,12 @@ func (b *backend) userExistenceCheck(ctx context.Context, req *logical.Request,
return userEntry != nil, nil
}
func (b *backend) user(s logical.Storage, username string) (*UserEntry, error) {
func (b *backend) user(ctx context.Context, s logical.Storage, username string) (*UserEntry, error) {
if username == "" {
return nil, fmt.Errorf("missing username")
}
entry, err := s.Get("user/" + strings.ToLower(username))
entry, err := s.Get(ctx, "user/"+strings.ToLower(username))
if err != nil {
return nil, err
}
@ -83,7 +83,7 @@ func (b *backend) user(s logical.Storage, username string) (*UserEntry, error) {
}
func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete("user/" + d.Get("name").(string))
err := req.Storage.Delete(ctx, "user/"+d.Get("name").(string))
if err != nil {
return nil, err
}
@ -92,7 +92,7 @@ func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
user, err := b.user(req.Storage, d.Get("name").(string))
user, err := b.user(ctx, req.Storage, d.Get("name").(string))
if err != nil {
return nil, err
}
@ -112,7 +112,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
var policies = policyutil.ParsePolicies(d.Get("policies"))
for _, policy := range policies {
if policy == "root" {
return logical.ErrorResponse("root policy cannot be granted by an authentication backend"), nil
return logical.ErrorResponse("root policy cannot be granted by an auth method"), nil
}
}
@ -123,7 +123,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@ -131,7 +131,7 @@ func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *fr
}
func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
users, err := req.Storage.List("user/")
users, err := req.Storage.List(ctx, "user/")
if err != nil {
return nil, err
}

View File

@ -0,0 +1,166 @@
package token
import (
"fmt"
"io"
"os"
"strconv"
"strings"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/password"
)
type CLIHandler struct {
// for tests
testStdin io.Reader
testStdout io.Writer
}
func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, error) {
// Parse "lookup" first - we want to return an early error if the user
// supplied an invalid value here before we prompt them for a token. It would
// be annoying to type your token and then be told you supplied an invalid
// value that we could have known in advance.
lookup := true
if x, ok := m["lookup"]; ok {
parsed, err := strconv.ParseBool(x)
if err != nil {
return nil, fmt.Errorf("Failed to parse \"lookup\" as boolean: %s", err)
}
lookup = parsed
}
// Parse the token.
token, ok := m["token"]
if !ok {
// Override the output
stdout := h.testStdout
if stdout == nil {
stdout = os.Stderr
}
// No arguments given, read the token from user input
fmt.Fprintf(stdout, "Token (will be hidden): ")
var err error
token, err = password.Read(os.Stdin)
fmt.Fprintf(stdout, "\n")
if err != nil {
if err == password.ErrInterrupted {
return nil, fmt.Errorf("user interrupted")
}
return nil, fmt.Errorf("An error occurred attempting to "+
"ask for a token. The raw error message is shown below, but usually "+
"this is because you attempted to pipe a value into the command or "+
"you are executing outside of a terminal (tty). If you want to pipe "+
"the value, pass \"-\" as the argument to read from stdin. The raw "+
"error was: %s", err)
}
}
// Remove any whitespace, etc.
token = strings.TrimSpace(token)
if token == "" {
return nil, fmt.Errorf(
"A token must be passed to auth. Please view the help for more " +
"information.")
}
// If the user declined verification, return now. Note that we will not have
// a lot of information about the token.
if !lookup {
return &api.Secret{
Auth: &api.SecretAuth{
ClientToken: token,
},
}, nil
}
// If we got this far, we want to lookup and lookup the token and pull it's
// list of policies an metadata.
c.SetToken(token)
c.SetWrappingLookupFunc(func(string, string) string { return "" })
secret, err := c.Auth().Token().LookupSelf()
if err != nil {
return nil, fmt.Errorf("Error looking up token: %s", err)
}
if secret == nil {
return nil, fmt.Errorf("Empty response from lookup-self")
}
// Return an auth struct that "looks" like the response from an auth method.
// lookup and lookup-self return their data in data, not auth. We try to
// mirror that data here.
id, err := secret.TokenID()
if err != nil {
return nil, fmt.Errorf("Error accessing token ID: %s", err)
}
accessor, err := secret.TokenAccessor()
if err != nil {
return nil, fmt.Errorf("Error accessing token accessor: %s", err)
}
policies, err := secret.TokenPolicies()
if err != nil {
return nil, fmt.Errorf("Error accessing token policies: %s", err)
}
metadata, err := secret.TokenMetadata()
if err != nil {
return nil, fmt.Errorf("Error accessing token metadata: %s", err)
}
dur, err := secret.TokenTTL()
if err != nil {
return nil, fmt.Errorf("Error converting token TTL: %s", err)
}
renewable, err := secret.TokenIsRenewable()
if err != nil {
return nil, fmt.Errorf("Error checking if token is renewable: %s", err)
}
return &api.Secret{
Auth: &api.SecretAuth{
ClientToken: id,
Accessor: accessor,
Policies: policies,
Metadata: metadata,
LeaseDuration: int(dur.Seconds()),
Renewable: renewable,
},
}, nil
}
func (h *CLIHandler) Help() string {
help := `
Usage: vault login TOKEN [CONFIG K=V...]
The token auth method allows logging in directly with a token. This
can be a token from the "token-create" command or API. There are no
configuration options for this auth method.
Authenticate using a token:
$ vault login 96ddf4bc-d217-f3ba-f9bd-017055595017
Authenticate but do not lookup information about the token:
$ vault login token=96ddf4bc-d217-f3ba-f9bd-017055595017 lookup=false
This token usually comes from a different source such as the API or via the
built-in "vault token create" command.
Configuration:
token=<string>
The token to use for authentication. This is usually provided directly
via the "vault login" command.
lookup=<bool>
Perform a lookup of the token's metadata and policies.
`
return strings.TrimSpace(help)
}

View File

@ -1,14 +1,16 @@
package userpass
import (
"context"
"github.com/hashicorp/vault/helper/mfa"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil

View File

@ -1,6 +1,7 @@
package userpass
import (
"context"
"fmt"
"reflect"
"testing"
@ -45,7 +46,7 @@ func TestBackend_TTLDurations(t *testing.T) {
data5 := map[string]interface{}{
"password": "password",
}
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,
@ -69,7 +70,7 @@ func TestBackend_TTLDurations(t *testing.T) {
}
func TestBackend_basic(t *testing.T) {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,
@ -92,7 +93,7 @@ func TestBackend_basic(t *testing.T) {
}
func TestBackend_userCrud(t *testing.T) {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,
@ -115,7 +116,7 @@ func TestBackend_userCrud(t *testing.T) {
}
func TestBackend_userCreateOperation(t *testing.T) {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,
@ -136,7 +137,7 @@ func TestBackend_userCreateOperation(t *testing.T) {
}
func TestBackend_passwordUpdate(t *testing.T) {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,
@ -161,7 +162,7 @@ func TestBackend_passwordUpdate(t *testing.T) {
}
func TestBackend_policiesUpdate(t *testing.T) {
b, err := Factory(&logical.BackendConfig{
b, err := Factory(context.Background(), &logical.BackendConfig{
Logger: nil,
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: testSysTTL,

View File

@ -30,9 +30,9 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
return nil, fmt.Errorf("'username' must be specified")
}
if data.Password == "" {
fmt.Printf("Password (will be hidden): ")
fmt.Fprintf(os.Stderr, "Password (will be hidden): ")
password, err := pwd.Read(os.Stdin)
fmt.Println()
fmt.Fprintf(os.Stderr, "\n")
if err != nil {
return nil, err
}
@ -66,20 +66,40 @@ func (h *CLIHandler) Auth(c *api.Client, m map[string]string) (*api.Secret, erro
func (h *CLIHandler) Help() string {
help := `
The "userpass"/"radius" credential provider allows you to authenticate with
a username and password. To use it, specify the "username" and "password"
parameters. If password is not provided on the command line, it will be
read from stdin.
Usage: vault login -method=userpass [CONFIG K=V...]
If multi-factor authentication (MFA) is enabled, a "method" and/or "passcode"
may be provided depending on the MFA backend enabled. To check
which MFA backend is in use, read "auth/[mount]/mfa_config".
The userpass auth method allows users to authenticate using Vault's
internal user database.
Example: vault auth -method=userpass \
username=<user> \
password=<password>
If MFA is enabled, a "method" and/or "passcode" may be required depending on
the MFA method. To check which MFA is in use, run:
`
$ vault read auth/<mount>/mfa_config
Authenticate as "sally":
$ vault login -method=userpass username=sally
Password (will be hidden):
Authenticate as "bob":
$ vault login -method=userpass username=bob password=password
Configuration:
method=<string>
MFA method.
passcode=<string>
MFA OTP/passcode.
password=<string>
Password to use for authentication. If not provided, the CLI will prompt
for this on stdin.
username=<string>
Username to use for authentication.
`
return strings.TrimSpace(help)
}

View File

@ -61,7 +61,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
}
// Get the user and validate auth
user, err := b.user(req.Storage, username)
user, err := b.user(ctx, req.Storage, username)
if err != nil {
return nil, err
}
@ -102,7 +102,7 @@ func (b *backend) pathLogin(ctx context.Context, req *logical.Request, d *framew
func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
// Get the user
user, err := b.user(req.Storage, req.Auth.Metadata["username"])
user, err := b.user(ctx, req.Storage, req.Auth.Metadata["username"])
if err != nil {
return nil, err
}

View File

@ -37,7 +37,7 @@ func pathUserPassword(b *backend) *framework.Path {
func (b *backend) pathUserPasswordUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
username := d.Get("username").(string)
userEntry, err := b.user(req.Storage, username)
userEntry, err := b.user(ctx, req.Storage, username)
if err != nil {
return nil, err
}
@ -53,7 +53,7 @@ func (b *backend) pathUserPasswordUpdate(ctx context.Context, req *logical.Reque
return logical.ErrorResponse(userErr.Error()), logical.ErrInvalidRequest
}
return nil, b.setUser(req.Storage, username, userEntry)
return nil, b.setUser(ctx, req.Storage, username, userEntry)
}
func (b *backend) updateUserPassword(req *logical.Request, d *framework.FieldData, userEntry *UserEntry) (error, error) {

View File

@ -35,7 +35,7 @@ func pathUserPolicies(b *backend) *framework.Path {
func (b *backend) pathUserPoliciesUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
username := d.Get("username").(string)
userEntry, err := b.user(req.Storage, username)
userEntry, err := b.user(ctx, req.Storage, username)
if err != nil {
return nil, err
}
@ -45,7 +45,7 @@ func (b *backend) pathUserPoliciesUpdate(ctx context.Context, req *logical.Reque
userEntry.Policies = policyutil.ParsePolicies(d.Get("policies"))
return nil, b.setUser(req.Storage, username, userEntry)
return nil, b.setUser(ctx, req.Storage, username, userEntry)
}
const pathUserPoliciesHelpSyn = `

View File

@ -69,7 +69,7 @@ func pathUsers(b *backend) *framework.Path {
}
func (b *backend) userExistenceCheck(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) {
userEntry, err := b.user(req.Storage, data.Get("username").(string))
userEntry, err := b.user(ctx, req.Storage, data.Get("username").(string))
if err != nil {
return false, err
}
@ -77,12 +77,12 @@ func (b *backend) userExistenceCheck(ctx context.Context, req *logical.Request,
return userEntry != nil, nil
}
func (b *backend) user(s logical.Storage, username string) (*UserEntry, error) {
func (b *backend) user(ctx context.Context, s logical.Storage, username string) (*UserEntry, error) {
if username == "" {
return nil, fmt.Errorf("missing username")
}
entry, err := s.Get("user/" + strings.ToLower(username))
entry, err := s.Get(ctx, "user/"+strings.ToLower(username))
if err != nil {
return nil, err
}
@ -98,17 +98,17 @@ func (b *backend) user(s logical.Storage, username string) (*UserEntry, error) {
return &result, nil
}
func (b *backend) setUser(s logical.Storage, username string, userEntry *UserEntry) error {
func (b *backend) setUser(ctx context.Context, s logical.Storage, username string, userEntry *UserEntry) error {
entry, err := logical.StorageEntryJSON("user/"+username, userEntry)
if err != nil {
return err
}
return s.Put(entry)
return s.Put(ctx, entry)
}
func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
users, err := req.Storage.List("user/")
users, err := req.Storage.List(ctx, "user/")
if err != nil {
return nil, err
}
@ -116,7 +116,7 @@ func (b *backend) pathUserList(ctx context.Context, req *logical.Request, d *fra
}
func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete("user/" + strings.ToLower(d.Get("username").(string)))
err := req.Storage.Delete(ctx, "user/"+strings.ToLower(d.Get("username").(string)))
if err != nil {
return nil, err
}
@ -125,7 +125,7 @@ func (b *backend) pathUserDelete(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
user, err := b.user(req.Storage, strings.ToLower(d.Get("username").(string)))
user, err := b.user(ctx, req.Storage, strings.ToLower(d.Get("username").(string)))
if err != nil {
return nil, err
}
@ -144,7 +144,7 @@ func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *fra
func (b *backend) userCreateUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
username := strings.ToLower(d.Get("username").(string))
userEntry, err := b.user(req.Storage, username)
userEntry, err := b.user(ctx, req.Storage, username)
if err != nil {
return nil, err
}
@ -182,7 +182,7 @@ func (b *backend) userCreateUpdate(ctx context.Context, req *logical.Request, d
return logical.ErrorResponse(fmt.Sprintf("err: %s", err)), nil
}
return nil, b.setUser(req.Storage, username, userEntry)
return nil, b.setUser(ctx, req.Storage, username, userEntry)
}
func (b *backend) pathUserWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {

View File

@ -1,6 +1,7 @@
package aws
import (
"context"
"strings"
"time"
@ -8,9 +9,9 @@ import (
"github.com/hashicorp/vault/logical/framework"
)
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil

View File

@ -2,6 +2,7 @@ package aws
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log"
@ -22,7 +23,7 @@ import (
)
func getBackend(t *testing.T) logical.Backend {
be, _ := Factory(logical.TestBackendConfig())
be, _ := Factory(context.Background(), logical.TestBackendConfig())
return be
}

View File

@ -1,6 +1,7 @@
package aws
import (
"context"
"fmt"
"os"
@ -13,11 +14,11 @@ import (
"github.com/hashicorp/vault/logical"
)
func getRootConfig(s logical.Storage, clientType string) (*aws.Config, error) {
func getRootConfig(ctx context.Context, s logical.Storage, clientType string) (*aws.Config, error) {
credsConfig := &awsutil.CredentialsConfig{}
var endpoint string
entry, err := s.Get("config/root")
entry, err := s.Get(ctx, "config/root")
if err != nil {
return nil, err
}
@ -63,8 +64,8 @@ func getRootConfig(s logical.Storage, clientType string) (*aws.Config, error) {
}, nil
}
func clientIAM(s logical.Storage) (*iam.IAM, error) {
awsConfig, err := getRootConfig(s, "iam")
func clientIAM(ctx context.Context, s logical.Storage) (*iam.IAM, error) {
awsConfig, err := getRootConfig(ctx, s, "iam")
if err != nil {
return nil, err
}
@ -77,8 +78,8 @@ func clientIAM(s logical.Storage) (*iam.IAM, error) {
return client, nil
}
func clientSTS(s logical.Storage) (*sts.STS, error) {
awsConfig, err := getRootConfig(s, "sts")
func clientSTS(ctx context.Context, s logical.Storage) (*sts.STS, error) {
awsConfig, err := getRootConfig(ctx, s, "sts")
if err != nil {
return nil, err
}

View File

@ -35,8 +35,8 @@ func pathConfigLease(b *backend) *framework.Path {
}
// Lease returns the lease information
func (b *backend) Lease(s logical.Storage) (*configLease, error) {
entry, err := s.Get("config/lease")
func (b *backend) Lease(ctx context.Context, s logical.Storage) (*configLease, error) {
entry, err := s.Get(ctx, "config/lease")
if err != nil {
return nil, err
}
@ -82,7 +82,7 @@ func (b *backend) pathLeaseWrite(ctx context.Context, req *logical.Request, d *f
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}
@ -90,7 +90,7 @@ func (b *backend) pathLeaseWrite(ctx context.Context, req *logical.Request, d *f
}
func (b *backend) pathLeaseRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
lease, err := b.Lease(req.Storage)
lease, err := b.Lease(ctx, req.Storage)
if err != nil {
return nil, err

View File

@ -60,7 +60,7 @@ func pathConfigRootWrite(ctx context.Context, req *logical.Request, data *framew
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@ -58,7 +58,7 @@ func pathRoles() *framework.Path {
}
func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entries, err := req.Storage.List("policy/")
entries, err := req.Storage.List(ctx, "policy/")
if err != nil {
return nil, err
}
@ -66,7 +66,7 @@ func (b *backend) pathRoleList(ctx context.Context, req *logical.Request, d *fra
}
func pathRolesDelete(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
err := req.Storage.Delete("policy/" + d.Get("name").(string))
err := req.Storage.Delete(ctx, "policy/"+d.Get("name").(string))
if err != nil {
return nil, err
}
@ -75,7 +75,7 @@ func pathRolesDelete(ctx context.Context, req *logical.Request, d *framework.Fie
}
func pathRolesRead(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
entry, err := req.Storage.Get("policy/" + d.Get("name").(string))
entry, err := req.Storage.Get(ctx, "policy/"+d.Get("name").(string))
if err != nil {
return nil, err
}
@ -125,7 +125,7 @@ func pathRolesWrite(ctx context.Context, req *logical.Request, d *framework.Fiel
"Error compacting policy: %s", err)), nil
}
// Write the policy into storage
err := req.Storage.Put(&logical.StorageEntry{
err := req.Storage.Put(ctx, &logical.StorageEntry{
Key: "policy/" + d.Get("name").(string),
Value: buf.Bytes(),
})
@ -134,7 +134,7 @@ func pathRolesWrite(ctx context.Context, req *logical.Request, d *framework.Fiel
}
} else {
// Write the arn ref into storage
err := req.Storage.Put(&logical.StorageEntry{
err := req.Storage.Put(ctx, &logical.StorageEntry{
Key: "policy/" + d.Get("name").(string),
Value: []byte(d.Get("arn").(string)),
})

View File

@ -15,7 +15,7 @@ func TestBackend_PathListRoles(t *testing.T) {
config.StorageView = &logical.InmemStorage{}
b := Backend()
if err := b.Setup(config); err != nil {
if err := b.Setup(context.Background(), config); err != nil {
t.Fatal(err)
}

View File

@ -45,7 +45,7 @@ func (b *backend) pathSTSRead(ctx context.Context, req *logical.Request, d *fram
ttl := int64(d.Get("ttl").(int))
// Read the policy
policy, err := req.Storage.Get("policy/" + policyName)
policy, err := req.Storage.Get(ctx, "policy/"+policyName)
if err != nil {
return nil, fmt.Errorf("error retrieving role: %s", err)
}
@ -57,6 +57,7 @@ func (b *backend) pathSTSRead(ctx context.Context, req *logical.Request, d *fram
if strings.HasPrefix(policyValue, "arn:") {
if strings.Contains(policyValue, ":role/") {
return b.assumeRole(
ctx,
req.Storage,
req.DisplayName, policyName, policyValue,
ttl,
@ -69,6 +70,7 @@ func (b *backend) pathSTSRead(ctx context.Context, req *logical.Request, d *fram
}
// Use the helper to create the secret
return b.secretTokenCreate(
ctx,
req.Storage,
req.DisplayName, policyName, policyValue,
ttl,

View File

@ -34,7 +34,7 @@ func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *fra
policyName := d.Get("name").(string)
// Read the policy
policy, err := req.Storage.Get("policy/" + policyName)
policy, err := req.Storage.Get(ctx, "policy/"+policyName)
if err != nil {
return nil, fmt.Errorf("error retrieving role: %s", err)
}
@ -45,10 +45,10 @@ func (b *backend) pathUserRead(ctx context.Context, req *logical.Request, d *fra
// Use the helper to create the secret
return b.secretAccessKeysCreate(
req.Storage, req.DisplayName, policyName, string(policy.Value))
ctx, req.Storage, req.DisplayName, policyName, string(policy.Value))
}
func pathUserRollback(req *logical.Request, _kind string, data interface{}) error {
func pathUserRollback(ctx context.Context, req *logical.Request, _kind string, data interface{}) error {
var entry walUser
if err := mapstructure.Decode(data, &entry); err != nil {
return err
@ -56,7 +56,7 @@ func pathUserRollback(req *logical.Request, _kind string, data interface{}) erro
username := entry.UserName
// Get the client
client, err := clientIAM(req.Storage)
client, err := clientIAM(ctx, req.Storage)
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package aws
import (
"context"
"fmt"
"github.com/hashicorp/vault/logical"
@ -11,11 +12,11 @@ var walRollbackMap = map[string]framework.WALRollbackFunc{
"user": pathUserRollback,
}
func walRollback(req *logical.Request, kind string, data interface{}) error {
func walRollback(ctx context.Context, req *logical.Request, kind string, data interface{}) error {
f, ok := walRollbackMap[kind]
if !ok {
return fmt.Errorf("unknown type to rollback")
}
return f(req, kind, data)
return f(ctx, req, kind, data)
}

View File

@ -65,10 +65,10 @@ func genUsername(displayName, policyName, userType string) (ret string, warning
return
}
func (b *backend) secretTokenCreate(s logical.Storage,
func (b *backend) secretTokenCreate(ctx context.Context, s logical.Storage,
displayName, policyName, policy string,
lifeTimeInSeconds int64) (*logical.Response, error) {
STSClient, err := clientSTS(s)
STSClient, err := clientSTS(ctx, s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
@ -110,10 +110,10 @@ func (b *backend) secretTokenCreate(s logical.Storage,
return resp, nil
}
func (b *backend) assumeRole(s logical.Storage,
func (b *backend) assumeRole(ctx context.Context, s logical.Storage,
displayName, policyName, policy string,
lifeTimeInSeconds int64) (*logical.Response, error) {
STSClient, err := clientSTS(s)
STSClient, err := clientSTS(ctx, s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
@ -156,9 +156,10 @@ func (b *backend) assumeRole(s logical.Storage,
}
func (b *backend) secretAccessKeysCreate(
ctx context.Context,
s logical.Storage,
displayName, policyName string, policy string) (*logical.Response, error) {
client, err := clientIAM(s)
client, err := clientIAM(ctx, s)
if err != nil {
return logical.ErrorResponse(err.Error()), nil
}
@ -169,7 +170,7 @@ func (b *backend) secretAccessKeysCreate(
// the user is created because if switch the order then the WAL put
// can fail, which would put us in an awkward position: we have a user
// we need to rollback but can't put the WAL entry to do the rollback.
walId, err := framework.PutWAL(s, "user", &walUser{
walId, err := framework.PutWAL(ctx, s, "user", &walUser{
UserName: username,
})
if err != nil {
@ -221,7 +222,7 @@ func (b *backend) secretAccessKeysCreate(
// Remove the WAL entry, we succeeded! If we fail, we don't return
// the secret because it'll get rolled back anyways, so we have to return
// an error here.
if err := framework.DeleteWAL(s, walId); err != nil {
if err := framework.DeleteWAL(ctx, s, walId); err != nil {
return nil, fmt.Errorf("Failed to commit WAL entry: %s", err)
}
@ -236,7 +237,7 @@ func (b *backend) secretAccessKeysCreate(
"is_sts": false,
})
lease, err := b.Lease(s)
lease, err := b.Lease(ctx, s)
if err != nil || lease == nil {
lease = &configLease{}
}
@ -262,7 +263,7 @@ func (b *backend) secretAccessKeysRenew(ctx context.Context, req *logical.Reques
}
}
lease, err := b.Lease(req.Storage)
lease, err := b.Lease(ctx, req.Storage)
if err != nil {
return nil, err
}
@ -302,7 +303,7 @@ func secretAccessKeysRevoke(ctx context.Context, req *logical.Request, d *framew
}
// Use the user rollback mechanism to delete this user
err := pathUserRollback(req, "user", map[string]interface{}{
err := pathUserRollback(ctx, req, "user", map[string]interface{}{
"username": username,
})
if err != nil {

View File

@ -1,6 +1,7 @@
package cassandra
import (
"context"
"fmt"
"strings"
"sync"
@ -11,9 +12,9 @@ import (
)
// Factory creates a new backend
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) {
b := Backend()
if err := b.Setup(conf); err != nil {
if err := b.Setup(ctx, conf); err != nil {
return nil, err
}
return b, nil
@ -43,7 +44,7 @@ func Backend() *backend {
Invalidate: b.invalidate,
Clean: func() {
Clean: func(_ context.Context) {
b.ResetDB(nil)
},
BackendType: logical.TypeLogical,
@ -77,7 +78,7 @@ type sessionConfig struct {
}
// DB returns the database connection.
func (b *backend) DB(s logical.Storage) (*gocql.Session, error) {
func (b *backend) DB(ctx context.Context, s logical.Storage) (*gocql.Session, error) {
b.lock.Lock()
defer b.lock.Unlock()
@ -86,7 +87,7 @@ func (b *backend) DB(s logical.Storage) (*gocql.Session, error) {
return b.session, nil
}
entry, err := s.Get("config/connection")
entry, err := s.Get(ctx, "config/connection")
if err != nil {
return nil, err
}
@ -120,7 +121,7 @@ func (b *backend) ResetDB(newSession *gocql.Session) {
b.session = newSession
}
func (b *backend) invalidate(key string) {
func (b *backend) invalidate(_ context.Context, key string) {
switch key {
case "config/connection":
b.ResetDB(nil)

View File

@ -1,6 +1,7 @@
package cassandra
import (
"context"
"fmt"
"log"
"os"
@ -82,7 +83,7 @@ func TestBackend_basic(t *testing.T) {
}
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
@ -106,7 +107,7 @@ func TestBackend_roleCrud(t *testing.T) {
}
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
b, err := Factory(config)
b, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}

View File

@ -87,7 +87,7 @@ take precedence.`,
}
func (b *backend) pathConnectionRead(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
entry, err := req.Storage.Get("config/connection")
entry, err := req.Storage.Get(ctx, "config/connection")
if err != nil {
return nil, err
}
@ -196,7 +196,7 @@ func (b *backend) pathConnectionWrite(ctx context.Context, req *logical.Request,
if err != nil {
return nil, err
}
if err := req.Storage.Put(entry); err != nil {
if err := req.Storage.Put(ctx, entry); err != nil {
return nil, err
}

View File

@ -36,7 +36,7 @@ func (b *backend) pathCredsCreateRead(ctx context.Context, req *logical.Request,
name := data.Get("name").(string)
// Get the role
role, err := getRole(req.Storage, name)
role, err := getRole(ctx, req.Storage, name)
if err != nil {
return nil, err
}
@ -57,7 +57,7 @@ func (b *backend) pathCredsCreateRead(ctx context.Context, req *logical.Request,
}
// Get our connection
session, err := b.DB(req.Storage)
session, err := b.DB(ctx, req.Storage)
if err != nil {
return nil, err
}

Some files were not shown because too many files have changed in this diff Show More