mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-20 06:01:10 +02:00
The result will still pass gofmtcheck and won't trigger additional changes if someone isn't using goimports, but it will avoid the piecemeal imports changes we've been seeing.
244 lines
6.2 KiB
Go
244 lines
6.2 KiB
Go
package sink
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io/ioutil"
|
|
"math/rand"
|
|
"os"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/hashicorp/errwrap"
|
|
hclog "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/vault/api"
|
|
"github.com/hashicorp/vault/helper/dhutil"
|
|
"github.com/hashicorp/vault/helper/jsonutil"
|
|
)
|
|
|
|
// Sink is a destination that a token can be written to. The sink server's
// run loop calls WriteToken with the final payload for the sink, which is the
// token itself or its wrapped and/or encrypted form.
type Sink interface {
	// WriteToken delivers the given token payload to the sink. A non-nil
	// error causes the sink server to retry the write after a backoff.
	WriteToken(string) error
}
|
|
|
|
type SinkConfig struct {
|
|
Sink
|
|
Logger hclog.Logger
|
|
Config map[string]interface{}
|
|
Client *api.Client
|
|
WrapTTL time.Duration
|
|
DHType string
|
|
DHPath string
|
|
AAD string
|
|
cachedRemotePubKey []byte
|
|
cachedPubKey []byte
|
|
cachedPriKey []byte
|
|
}
|
|
|
|
type SinkServerConfig struct {
|
|
Logger hclog.Logger
|
|
Client *api.Client
|
|
Context context.Context
|
|
ExitAfterAuth bool
|
|
}
|
|
|
|
// SinkServer is responsible for pushing tokens to sinks
|
|
type SinkServer struct {
|
|
DoneCh chan struct{}
|
|
logger hclog.Logger
|
|
client *api.Client
|
|
random *rand.Rand
|
|
exitAfterAuth bool
|
|
remaining *int32
|
|
}
|
|
|
|
func NewSinkServer(conf *SinkServerConfig) *SinkServer {
|
|
ss := &SinkServer{
|
|
DoneCh: make(chan struct{}),
|
|
logger: conf.Logger,
|
|
client: conf.Client,
|
|
random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
|
|
exitAfterAuth: conf.ExitAfterAuth,
|
|
remaining: new(int32),
|
|
}
|
|
|
|
return ss
|
|
}
|
|
|
|
// Run executes the server's run loop, which is responsible for reading
|
|
// in new tokens and pushing them out to the various sinks.
|
|
func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*SinkConfig) {
|
|
if incoming == nil {
|
|
panic("incoming or shutdown channel are nil")
|
|
}
|
|
|
|
ss.logger.Info("starting sink server")
|
|
defer func() {
|
|
ss.logger.Info("sink server stopped")
|
|
close(ss.DoneCh)
|
|
}()
|
|
|
|
latestToken := new(string)
|
|
sinkCh := make(chan func() error, len(sinks))
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
|
|
case token := <-incoming:
|
|
if token != *latestToken {
|
|
|
|
// Drain the existing funcs
|
|
drainLoop:
|
|
for {
|
|
select {
|
|
case <-sinkCh:
|
|
atomic.AddInt32(ss.remaining, -1)
|
|
default:
|
|
break drainLoop
|
|
}
|
|
}
|
|
|
|
*latestToken = token
|
|
|
|
sinkFunc := func(currSink *SinkConfig, currToken string) func() error {
|
|
return func() error {
|
|
if currToken != *latestToken {
|
|
return nil
|
|
}
|
|
var err error
|
|
|
|
if currSink.WrapTTL != 0 {
|
|
if currToken, err = currSink.wrapToken(ss.client, currSink.WrapTTL, currToken); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if currSink.DHType != "" {
|
|
if currToken, err = currSink.encryptToken(currToken); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return currSink.WriteToken(currToken)
|
|
}
|
|
}
|
|
|
|
for _, s := range sinks {
|
|
atomic.AddInt32(ss.remaining, 1)
|
|
sinkCh <- sinkFunc(s, token)
|
|
}
|
|
}
|
|
|
|
case sinkFunc := <-sinkCh:
|
|
atomic.AddInt32(ss.remaining, -1)
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
if err := sinkFunc(); err != nil {
|
|
backoff := 2*time.Second + time.Duration(ss.random.Int63()%int64(time.Second*2)-int64(time.Second))
|
|
ss.logger.Error("error returned by sink function, retrying", "error", err, "backoff", backoff.String())
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-time.After(backoff):
|
|
atomic.AddInt32(ss.remaining, 1)
|
|
sinkCh <- sinkFunc
|
|
}
|
|
} else {
|
|
if atomic.LoadInt32(ss.remaining) == 0 && ss.exitAfterAuth {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *SinkConfig) encryptToken(token string) (string, error) {
|
|
var aesKey []byte
|
|
var err error
|
|
resp := new(dhutil.Envelope)
|
|
switch s.DHType {
|
|
case "curve25519":
|
|
if len(s.cachedRemotePubKey) == 0 {
|
|
_, err = os.Lstat(s.DHPath)
|
|
if err != nil {
|
|
if !os.IsNotExist(err) {
|
|
return "", errwrap.Wrapf("error stat-ing dh parameters file: {{err}}", err)
|
|
}
|
|
return "", errors.New("no dh parameters file found, and no cached pub key")
|
|
}
|
|
fileBytes, err := ioutil.ReadFile(s.DHPath)
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error reading file for dh parameters: {{err}}", err)
|
|
}
|
|
theirPubKey := new(dhutil.PublicKeyInfo)
|
|
if err := jsonutil.DecodeJSON(fileBytes, theirPubKey); err != nil {
|
|
return "", errwrap.Wrapf("error decoding public key: {{err}}", err)
|
|
}
|
|
if len(theirPubKey.Curve25519PublicKey) == 0 {
|
|
return "", errors.New("public key is nil")
|
|
}
|
|
s.cachedRemotePubKey = theirPubKey.Curve25519PublicKey
|
|
}
|
|
if len(s.cachedPubKey) == 0 {
|
|
s.cachedPubKey, s.cachedPriKey, err = dhutil.GeneratePublicPrivateKey()
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error generating pub/pri curve25519 keys: {{err}}", err)
|
|
}
|
|
}
|
|
resp.Curve25519PublicKey = s.cachedPubKey
|
|
}
|
|
|
|
aesKey, err = dhutil.GenerateSharedKey(s.cachedPriKey, s.cachedRemotePubKey)
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error deriving shared key: {{err}}", err)
|
|
}
|
|
if len(aesKey) == 0 {
|
|
return "", errors.New("derived AES key is empty")
|
|
}
|
|
|
|
resp.EncryptedPayload, resp.Nonce, err = dhutil.EncryptAES(aesKey, []byte(token), []byte(s.AAD))
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error encrypting with shared key: {{err}}", err)
|
|
}
|
|
m, err := jsonutil.EncodeJSON(resp)
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error encoding encrypted payload: {{err}}", err)
|
|
}
|
|
|
|
return string(m), nil
|
|
}
|
|
|
|
func (s *SinkConfig) wrapToken(client *api.Client, wrapTTL time.Duration, token string) (string, error) {
|
|
wrapClient, err := client.Clone()
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error deriving client for wrapping, not writing out to sink: {{err}})", err)
|
|
}
|
|
wrapClient.SetToken(token)
|
|
wrapClient.SetWrappingLookupFunc(func(string, string) string {
|
|
return wrapTTL.String()
|
|
})
|
|
secret, err := wrapClient.Logical().Write("sys/wrapping/wrap", map[string]interface{}{
|
|
"token": token,
|
|
})
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error wrapping token, not writing out to sink: {{err}})", err)
|
|
}
|
|
if secret == nil {
|
|
return "", errors.New("nil secret returned, not writing out to sink")
|
|
}
|
|
if secret.WrapInfo == nil {
|
|
return "", errors.New("nil wrap info returned, not writing out to sink")
|
|
}
|
|
|
|
m, err := jsonutil.EncodeJSON(secret.WrapInfo)
|
|
if err != nil {
|
|
return "", errwrap.Wrapf("error marshaling token, not writing out to sink: {{err}})", err)
|
|
}
|
|
|
|
return string(m), nil
|
|
}
|