2025-04-25 16:12:13 +00:00

116 lines
2.4 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package managed_key
import (
"context"
"crypto"
"encoding/asn1"
"fmt"
"io"
"math/big"
"github.com/hashicorp/vault/sdk/logical"
"golang.org/x/crypto/ssh"
)
// ManagedKeyInfo bundles an SSH public key with the identifiers and handles
// used to sign with the corresponding managed key. Its PublicKey, Sign and
// SignWithAlgorithm methods expose it as an SSH signer.
type ManagedKeyInfo struct {
// publicKey is the value returned by PublicKey; its Type() drives algorithm
// selection in SignWithAlgorithm.
publicKey ssh.PublicKey
// Name is a name-based identifier for the key.
// NOTE(review): not read by any method visible in this file — confirm usage.
Name NameKey
// Uuid is the key identifier passed to the package-level Sign function.
Uuid UUIDKey
// ctx is the context passed to the package-level Sign function on every
// signing call.
ctx context.Context
// mkv is the backend view passed to the package-level Sign function.
mkv SSHManagedKeyView
}
// managedKeyId is satisfied by any key identifier that can render itself as a
// string; UUIDKey and NameKey both implement it.
type managedKeyId interface {
String() string
}
// SSHManagedKeyView is the set of backend capabilities this package depends on
// for managed keys. ManagedKeyInfo stores one and passes it to the
// package-level Sign function.
type SSHManagedKeyView interface {
// BackendUUID returns a string identifying the backend.
// NOTE(review): presumably the mount's UUID — confirm against implementers.
BackendUUID() string
// GetManagedKeyView returns the logical.ManagedKeySystemView for the backend,
// or an error if one cannot be provided.
GetManagedKeyView() (logical.ManagedKeySystemView, error)
// GetRandomReader returns an io.Reader.
// NOTE(review): presumably the entropy source for signing — confirm.
GetRandomReader() io.Reader
}
// UUIDKey is the string type used for ManagedKeyInfo.Uuid.
type UUIDKey string

// NameKey is the string type used for ManagedKeyInfo.Name.
type NameKey string
// ecdsaSignature holds the two integers of an ECDSA signature. It is the
// target of asn1.Unmarshal and the argument to ssh.Marshal in
// SignWithAlgorithm, so the field order (R, then S) is significant.
type ecdsaSignature struct {
	R *big.Int
	S *big.Int
}
// HashFunctions maps an SSH signature algorithm name to the hash that
// SignWithAlgorithm applies to the payload before signing. Algorithms with no
// entry here are rejected by SignWithAlgorithm.
var HashFunctions = map[string]crypto.Hash{
	// RSA signature algorithms.
	ssh.KeyAlgoRSA:       crypto.SHA1,
	ssh.KeyAlgoRSASHA256: crypto.SHA256,
	ssh.KeyAlgoRSASHA512: crypto.SHA512,

	// ECDSA signature algorithms, one per curve.
	ssh.KeyAlgoECDSA256: crypto.SHA256,
	ssh.KeyAlgoECDSA384: crypto.SHA384,
	ssh.KeyAlgoECDSA521: crypto.SHA512,
}
// String returns the identifier as a plain string.
func (id UUIDKey) String() string {
	return string(id)
}
// String returns the name as a plain string.
func (name NameKey) String() string {
	return string(name)
}
// PublicKey returns the SSH public key stored in m.
func (m ManagedKeyInfo) PublicKey() ssh.PublicKey {
	pub := m.publicKey
	return pub
}
// Sign signs data by delegating to SignWithAlgorithm with an empty algorithm
// name, which makes SignWithAlgorithm pick its default for the key's type.
func (m ManagedKeyInfo) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
	// The empty string asks SignWithAlgorithm to choose the algorithm.
	const useDefault = ""
	return m.SignWithAlgorithm(rand, data, useDefault)
}
// SignWithAlgorithm signs input with the managed key identified by m.Uuid and
// returns the result as an SSH signature whose Format is the algorithm used.
//
// algorithm names the SSH signature algorithm. An empty string or "default"
// selects rsa-sha2-256 for RSA keys and the key's own type for every other
// key. The input is hashed locally with the hash that HashFunctions associates
// with the algorithm, and the digest is passed to the package-level Sign
// function together with rand.
//
// An error is returned when the algorithm has no entry in HashFunctions, when
// the algorithm cannot be used with the key's type, when the key type is
// neither RSA nor ECDSA, when signing fails, or when an ECDSA signature
// cannot be decoded.
func (m ManagedKeyInfo) SignWithAlgorithm(rand io.Reader, input []byte, algorithm string) (*ssh.Signature, error) {
	keyType := m.publicKey.Type()

	// The RSA SHA-2 algorithms are only meaningful for RSA keys.
	if keyType != ssh.KeyAlgoRSA && (algorithm == ssh.KeyAlgoRSASHA256 || algorithm == ssh.KeyAlgoRSASHA512) {
		return nil, fmt.Errorf("invalid algorithm %s for key type %s", algorithm, keyType)
	}

	if algorithm == "default" || algorithm == "" {
		if keyType == ssh.KeyAlgoRSA {
			algorithm = ssh.KeyAlgoRSASHA256
		} else {
			algorithm = keyType
		}
	}

	hashAlg, ok := HashFunctions[algorithm]
	if !ok {
		return nil, fmt.Errorf("invalid hash algorithm %s", algorithm)
	}

	// Reject every remaining key type / algorithm mismatch before the managed
	// key is used. Without this, an ECDSA key could be asked for "ssh-rsa" (or
	// an RSA key for an ECDSA algorithm) and the result would be labelled with
	// a format the key cannot produce; an unsupported key type would yield a
	// signature with an empty blob and no error.
	switch keyType {
	case ssh.KeyAlgoRSA:
		if algorithm != ssh.KeyAlgoRSA && algorithm != ssh.KeyAlgoRSASHA256 && algorithm != ssh.KeyAlgoRSASHA512 {
			return nil, fmt.Errorf("invalid algorithm %s for key type %s", algorithm, keyType)
		}
	case ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521:
		// Each ECDSA curve has exactly one signature algorithm: its own name.
		if algorithm != keyType {
			return nil, fmt.Errorf("invalid algorithm %s for key type %s", algorithm, keyType)
		}
	default:
		return nil, fmt.Errorf("unsupported key type %s", keyType)
	}

	hasher := hashAlg.New()
	hasher.Write(input) // hash.Hash.Write never returns an error.
	digest := hasher.Sum(nil)

	sig, err := Sign(m.ctx, m.mkv, m.Uuid, rand, digest, hashAlg)
	if err != nil {
		return nil, err
	}

	// RSA signatures are used as returned.
	if keyType == ssh.KeyAlgoRSA {
		return &ssh.Signature{
			Format: algorithm,
			Blob:   sig,
		}, nil
	}

	// ECDSA: the signature is decoded as an ASN.1 (R, S) sequence and
	// re-encoded in SSH wire format.
	var ecdsaSig ecdsaSignature
	if _, err := asn1.Unmarshal(sig, &ecdsaSig); err != nil {
		return nil, fmt.Errorf("error unmarshalling signature: %w", err)
	}

	return &ssh.Signature{
		Format: algorithm,
		Blob:   ssh.Marshal(ecdsaSig),
	}, nil
}