vault/physical/azure/azure.go
hashicorp-copywrite[bot] 0b12cdcfd1
[COMPLIANCE] License changes (#22290)
* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Updating the license from MPL to Business Source License.

Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at https://hashi.co/bsl-blog, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl.

* add missing license headers

* Update copyright file headers to BUS-1.1

* Fix test that expected exact offset on hcl file

---------

Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
Co-authored-by: Sarah Thompson <sthompson@hashicorp.com>
Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
2023-08-10 18:14:03 -07:00

331 lines
8.9 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package azure
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/url"
"os"
"sort"
"strconv"
"strings"
"time"
"github.com/Azure/azure-storage-blob-go/azblob"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/armon/go-metrics"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/sdk/physical"
)
// Limits this backend applies when talking to Azure Blob Storage.
const (
	// MaxBlobSize is 4 MiB. Put rejects values whose length is greater than
	// or equal to this, and it is also used as the upload block size.
	MaxBlobSize = 4 << 20

	// MaxListResults is the page size requested on each list call. It is the
	// service's current default, set explicitly so it cannot silently drift.
	MaxListResults = 5000
)
// AzureBackend is a physical backend that stores each entry as a block blob
// within a single Azure blob container.
type AzureBackend struct {
	// container is the blob container every key is read from / written to.
	container *azblob.ContainerURL
	// logger receives the backend's diagnostic output.
	logger log.Logger
	// permitPool is acquired around each storage operation to bound the
	// number of concurrent requests (sized from the max_parallel setting).
	permitPool *physical.PermitPool
}

// Compile-time check that AzureBackend implements physical.Backend.
var _ physical.Backend = (*AzureBackend)(nil)
// NewAzureBackend constructs an Azure backend backed by a blob container,
// creating the container if it does not exist yet. Credentials can be
// provided to the backend, sourced from the environment, via HCL or by using
// managed identities.
//
// Each setting is read from its environment variable first and falls back to
// the corresponding conf key:
//
//   - AZURE_BLOB_CONTAINER / "container" (required)
//   - AZURE_ACCOUNT_NAME / "accountName" (required)
//   - AZURE_ACCOUNT_KEY / "accountKey"; when neither is set, a managed-identity
//     token is fetched from IMDS instead of using shared-key auth
//   - AZURE_ENVIRONMENT / "environment", defaulting to "AzurePublicCloud"
//   - AZURE_ARM_ENDPOINT / "arm_endpoint"; when set, the environment is looked
//     up from this URL instead of by name
//
// conf["max_parallel"] (optional integer) sizes the permit pool, and
// conf["testHost"] (optional) replaces the storage URL with
// http://<testHost>/<accountName>/<container>.
//
// The container is probed with a 5 second timeout. It is created when the
// service reports it as not found. Any other probe failure is returned as an
// error, including non-service errors such as network failures or the probe
// timing out.
func NewAzureBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
	name := os.Getenv("AZURE_BLOB_CONTAINER")
	useMSI := false

	if name == "" {
		name = conf["container"]
		if name == "" {
			return nil, fmt.Errorf("'container' must be set")
		}
	}

	accountName := os.Getenv("AZURE_ACCOUNT_NAME")
	if accountName == "" {
		accountName = conf["accountName"]
		if accountName == "" {
			return nil, fmt.Errorf("'accountName' must be set")
		}
	}

	accountKey := os.Getenv("AZURE_ACCOUNT_KEY")
	if accountKey == "" {
		accountKey = conf["accountKey"]
		if accountKey == "" {
			logger.Info("accountKey not set, using managed identity auth")
			useMSI = true
		}
	}

	environmentName := os.Getenv("AZURE_ENVIRONMENT")
	if environmentName == "" {
		environmentName = conf["environment"]
		if environmentName == "" {
			environmentName = "AzurePublicCloud"
		}
	}

	environmentURL := os.Getenv("AZURE_ARM_ENDPOINT")
	if environmentURL == "" {
		environmentURL = conf["arm_endpoint"]
	}

	var environment azure.Environment
	var URL *url.URL
	var err error

	testHost := conf["testHost"]
	switch {
	case testHost != "":
		URL = &url.URL{Scheme: "http", Host: testHost, Path: fmt.Sprintf("/%s/%s", accountName, name)}
	default:
		if environmentURL != "" {
			environment, err = azure.EnvironmentFromURL(environmentURL)
			if err != nil {
				return nil, fmt.Errorf("failed to look up Azure environment descriptor for URL %q: %w", environmentURL, err)
			}
		} else {
			environment, err = azure.EnvironmentFromName(environmentName)
			if err != nil {
				return nil, fmt.Errorf("failed to look up Azure environment descriptor for name %q: %w", environmentName, err)
			}
		}
		URL, err = url.Parse(
			fmt.Sprintf("https://%s.blob.%s/%s", accountName, environment.StorageEndpointSuffix, name))
		if err != nil {
			return nil, fmt.Errorf("failed to create Azure client: %w", err)
		}
	}

	var credential azblob.Credential
	if useMSI {
		authToken, err := getAuthTokenFromIMDS(environment.ResourceIdentifiers.Storage)
		if err != nil {
			return nil, fmt.Errorf("failed to obtain auth token from IMDS %q: %w", environmentName, err)
		}

		// Returning 0 from the refresher permanently stops azblob from calling
		// it again, which would leave the backend with a token that eventually
		// expires. On a failed refresh, retry after this delay instead.
		const tokenRetryInterval = time.Minute

		credential = azblob.NewTokenCredential(authToken.OAuthToken(), func(c azblob.TokenCredential) time.Duration {
			// The refresher keeps being invoked after NewAzureBackend has
			// returned, so it must only use closure-local error variables.
			if err := authToken.Refresh(); err != nil {
				logger.Error("couldn't refresh token credential", "error", err)
				return tokenRetryInterval
			}
			// The refresh succeeded, so publish the new token right away.
			c.SetToken(authToken.OAuthToken())

			expIn, err := authToken.Token().ExpiresIn.Int64()
			if err != nil {
				logger.Error("couldn't retrieve jwt claim for 'expiresIn' from refreshed token", "error", err)
				return tokenRetryInterval
			}
			logger.Debug("token refreshed, new token expires in", "access_token_expiry", expIn)

			// Tokens are valid for 23h59m (86399s) by default; refresh after
			// 90% of the reported lifetime (~21h).
			next := time.Duration(int(float64(expIn)*0.9)) * time.Second
			if next <= 0 {
				// A zero/negative interval would also stop refreshing.
				return tokenRetryInterval
			}
			return next
		})
	} else {
		credential, err = azblob.NewSharedKeyCredential(accountName, accountKey)
		if err != nil {
			return nil, fmt.Errorf("failed to create Azure client: %w", err)
		}
	}

	p := azblob.NewPipeline(credential, azblob.PipelineOptions{})

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	containerURL := azblob.NewContainerURL(*URL, p)
	_, err = containerURL.GetProperties(ctx, azblob.LeaseAccessConditions{})
	if err != nil {
		var e azblob.StorageError
		if !errors.As(err, &e) {
			// Not a service response (e.g. network failure or the probe timing
			// out): do not start against a container we could not reach.
			return nil, fmt.Errorf("failed to get properties for container %q: %w", name, err)
		}
		switch e.ServiceCode() {
		case azblob.ServiceCodeContainerNotFound:
			if _, err := containerURL.Create(ctx, azblob.Metadata{}, azblob.PublicAccessNone); err != nil {
				return nil, fmt.Errorf("failed to create %q container: %w", name, err)
			}
		default:
			return nil, fmt.Errorf("failed to get properties for container %q: %w", name, err)
		}
	}

	var maxParInt int
	if maxParStr, ok := conf["max_parallel"]; ok {
		maxParInt, err = strconv.Atoi(maxParStr)
		if err != nil {
			return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
		}
		if logger.IsDebug() {
			logger.Debug("max_parallel set", "max_parallel", maxParInt)
		}
	}

	a := &AzureBackend{
		container:  &containerURL,
		logger:     logger,
		permitPool: physical.NewPermitPool(maxParInt),
	}
	return a, nil
}
// Put inserts or updates entry as a block blob named entry.Key.
//
// Values whose length is MaxBlobSize or more are rejected before any request
// is made. Upload failures are wrapped with the key (using %w, so the
// underlying azblob error remains inspectable), consistent with Get and
// Delete.
func (a *AzureBackend) Put(ctx context.Context, entry *physical.Entry) error {
	defer metrics.MeasureSince([]string{"azure", "put"}, time.Now())

	if len(entry.Value) >= MaxBlobSize {
		return fmt.Errorf("value is bigger than the current supported limit of 4MBytes")
	}

	a.permitPool.Acquire()
	defer a.permitPool.Release()

	blobURL := a.container.NewBlockBlobURL(entry.Key)
	_, err := azblob.UploadBufferToBlockBlob(ctx, entry.Value, blobURL, azblob.UploadToBlockBlobOptions{
		BlockSize: MaxBlobSize,
	})
	if err != nil {
		return fmt.Errorf("failed to upload blob %q: %w", entry.Key, err)
	}
	return nil
}
// Get fetches the blob named key and returns it as a physical.Entry.
//
// A missing blob is not an error: (nil, nil) is returned. Storage service
// errors are wrapped with the key; other download errors are returned as-is.
// If the blob body cannot be read completely, a nil entry is returned along
// with the error instead of a partially-populated one.
func (a *AzureBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
	defer metrics.MeasureSince([]string{"azure", "get"}, time.Now())

	a.permitPool.Acquire()
	defer a.permitPool.Release()

	blobURL := a.container.NewBlockBlobURL(key)
	clientOptions := azblob.ClientProvidedKeyOptions{}
	res, err := blobURL.Download(ctx, 0, azblob.CountToEnd, azblob.BlobAccessConditions{}, false, clientOptions)
	if err != nil {
		var e azblob.StorageError
		if errors.As(err, &e) {
			if e.ServiceCode() == azblob.ServiceCodeBlobNotFound {
				return nil, nil
			}
			return nil, fmt.Errorf("failed to download blob %q: %w", key, err)
		}
		return nil, err
	}

	reader := res.Body(azblob.RetryReaderOptions{})
	defer reader.Close()

	data, err := ioutil.ReadAll(reader)
	if err != nil {
		return nil, fmt.Errorf("failed to read blob %q: %w", key, err)
	}
	return &physical.Entry{Key: key, Value: data}, nil
}
// Delete permanently removes the blob stored under key, including its
// snapshots. Deleting a key that does not exist is treated as success.
func (a *AzureBackend) Delete(ctx context.Context, key string) error {
	defer metrics.MeasureSince([]string{"azure", "delete"}, time.Now())

	a.permitPool.Acquire()
	defer a.permitPool.Release()

	blob := a.container.NewBlockBlobURL(key)
	_, err := blob.Delete(ctx, azblob.DeleteSnapshotsOptionInclude, azblob.BlobAccessConditions{})
	if err == nil {
		return nil
	}

	var storageErr azblob.StorageError
	if !errors.As(err, &storageErr) {
		// Not a storage service error: pass it through unchanged.
		return err
	}
	if storageErr.ServiceCode() == azblob.ServiceCodeBlobNotFound {
		return nil
	}
	return fmt.Errorf("failed to delete blob %q: %w", key, err)
}
// List returns, in sorted order, the keys directly under prefix. A blob that
// lives deeper than one level is reported once as its first path segment with
// a trailing "/" (for prefix "a/", blob "a/b/c" yields "b/").
func (a *AzureBackend) List(ctx context.Context, prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"azure", "list"}, time.Now())

	a.permitPool.Acquire()
	defer a.permitPool.Release()

	opts := azblob.ListBlobsSegmentOptions{
		Prefix:     prefix,
		MaxResults: MaxListResults,
	}

	var keys []string
	marker := azblob.Marker{}
	for marker.NotDone() {
		page, err := a.container.ListBlobsFlatSegment(ctx, marker, opts)
		if err != nil {
			return nil, err
		}

		for _, item := range page.Segment.BlobItems {
			rel := strings.TrimPrefix(item.Name, prefix)
			slash := strings.IndexByte(rel, '/')
			if slash < 0 {
				// Leaf blob directly under prefix.
				keys = append(keys, rel)
				continue
			}
			// Nested blob: record its "directory" only once.
			// NOTE(review): AppendIfMissing scans keys linearly, so very large
			// listings are quadratic here; a set would need strutil dropped.
			keys = strutil.AppendIfMissing(keys, rel[:slash+1])
		}

		marker = page.NextMarker
	}

	sort.Strings(keys)
	return keys, nil
}
// getAuthTokenFromIMDS uses the Azure Instance Metadata Service (IMDS) to
// obtain a short-lived managed-identity OAuth token for resource. The token is
// refreshed once before being returned, and an error is returned if that
// refresh leaves it empty, so callers never receive (nil, nil).
// More info: https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview
func getAuthTokenFromIMDS(resource string) (*adal.ServicePrincipalToken, error) {
	msiEndpoint, err := adal.GetMSIVMEndpoint()
	if err != nil {
		return nil, err
	}

	spt, err := adal.NewServicePrincipalTokenFromMSI(msiEndpoint, resource)
	if err != nil {
		return nil, err
	}

	if err := spt.Refresh(); err != nil {
		return nil, err
	}

	// Previously this path returned the (nil) outer err, producing (nil, nil),
	// and the caller then dereferenced the nil token.
	if spt.Token().IsZero() {
		return nil, errors.New("managed identity token obtained from IMDS is empty")
	}
	return spt, nil
}