mirror of
https://github.com/tailscale/tailscale.git
synced 2025-12-02 07:52:05 +01:00
feature/tpm: protect all TPM handle operations with a mutex (#17708)
In particular on Windows, the `transport.TPMCloser` we get is not safe for concurrent use. This is especially noticeable because `tpm.attestationKey.Clone` uses the same open handle as the original key. So wrap the operations on ak.tpm with a mutex and make a deep copy with a new connection in Clone. Updates #15830 Updates #17662 Updates #17644 Signed-off-by: Andrew Lytvynov <awly@tailscale.com>
This commit is contained in:
parent
b6c6960e40
commit
f522b9dbb7
@ -10,6 +10,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/google/go-tpm/tpm2"
|
||||
"github.com/google/go-tpm/tpm2/transport"
|
||||
@ -19,6 +20,7 @@ import (
|
||||
)
|
||||
|
||||
type attestationKey struct {
|
||||
tpmMu sync.Mutex
|
||||
tpm transport.TPMCloser
|
||||
// private and public parts of the TPM key as returned from tpm2.Create.
|
||||
// These are used for serialization.
|
||||
@ -144,7 +146,7 @@ type attestationKeySerialized struct {
|
||||
|
||||
// MarshalJSON implements json.Marshaler.
|
||||
func (ak *attestationKey) MarshalJSON() ([]byte, error) {
|
||||
if ak == nil || ak.IsZero() {
|
||||
if ak == nil || len(ak.tpmPublic.Bytes()) == 0 || len(ak.tpmPrivate.Buffer) == 0 {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return json.Marshal(attestationKeySerialized{
|
||||
@ -163,6 +165,13 @@ func (ak *attestationKey) UnmarshalJSON(data []byte) (retErr error) {
|
||||
ak.tpmPrivate = tpm2.TPM2BPrivate{Buffer: aks.TPMPrivate}
|
||||
ak.tpmPublic = tpm2.BytesAs2B[tpm2.TPMTPublic, *tpm2.TPMTPublic](aks.TPMPublic)
|
||||
|
||||
ak.tpmMu.Lock()
|
||||
defer ak.tpmMu.Unlock()
|
||||
if ak.tpm != nil {
|
||||
ak.tpm.Close()
|
||||
ak.tpm = nil
|
||||
}
|
||||
|
||||
tpm, err := open()
|
||||
if err != nil {
|
||||
return key.ErrUnsupported
|
||||
@ -182,6 +191,9 @@ func (ak *attestationKey) Public() crypto.PublicKey {
|
||||
}
|
||||
|
||||
func (ak *attestationKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
|
||||
ak.tpmMu.Lock()
|
||||
defer ak.tpmMu.Unlock()
|
||||
|
||||
if !ak.loaded() {
|
||||
return nil, errors.New("tpm2 attestation key is not loaded during Sign")
|
||||
}
|
||||
@ -247,6 +259,9 @@ func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) {
|
||||
}
|
||||
|
||||
func (ak *attestationKey) Close() error {
|
||||
ak.tpmMu.Lock()
|
||||
defer ak.tpmMu.Unlock()
|
||||
|
||||
var errs []error
|
||||
if ak.handle != nil && ak.tpm != nil {
|
||||
_, err := tpm2.FlushContext{FlushHandle: ak.handle.Handle}.Execute(ak.tpm)
|
||||
@ -262,18 +277,31 @@ func (ak *attestationKey) Clone() key.HardwareAttestationKey {
|
||||
if ak == nil {
|
||||
return nil
|
||||
}
|
||||
return &attestationKey{
|
||||
tpm: ak.tpm,
|
||||
|
||||
tpm, err := open()
|
||||
if err != nil {
|
||||
log.Printf("[unexpected] failed to open a TPM connection in feature/tpm.attestationKey.Clone: %v", err)
|
||||
return nil
|
||||
}
|
||||
akc := &attestationKey{
|
||||
tpm: tpm,
|
||||
tpmPrivate: ak.tpmPrivate,
|
||||
tpmPublic: ak.tpmPublic,
|
||||
handle: ak.handle,
|
||||
pub: ak.pub,
|
||||
}
|
||||
if err := akc.load(); err != nil {
|
||||
log.Printf("[unexpected] failed to load TPM key in feature/tpm.attestationKey.Clone: %v", err)
|
||||
tpm.Close()
|
||||
return nil
|
||||
}
|
||||
return akc
|
||||
}
|
||||
|
||||
func (ak *attestationKey) IsZero() bool {
|
||||
if ak == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
ak.tpmMu.Lock()
|
||||
defer ak.tpmMu.Unlock()
|
||||
return !ak.loaded()
|
||||
}
|
||||
|
||||
@ -10,6 +10,8 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -62,6 +64,37 @@ func TestAttestationKeySign(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttestationKeySignConcurrent(t *testing.T) {
|
||||
skipWithoutTPM(t)
|
||||
ak, err := newAttestationKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := ak.Close(); err != nil {
|
||||
t.Errorf("ak.Close: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
data := []byte("secrets")
|
||||
digest := sha256.Sum256(data)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for range runtime.GOMAXPROCS(-1) {
|
||||
wg.Go(func() {
|
||||
// Check signature/validation round trip.
|
||||
sig, err := ak.Sign(rand.Reader, digest[:], crypto.SHA256)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !ecdsa.VerifyASN1(ak.Public().(*ecdsa.PublicKey), digest[:], sig) {
|
||||
t.Errorf("ecdsa.VerifyASN1 failed")
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAttestationKeyUnmarshal(t *testing.T) {
|
||||
skipWithoutTPM(t)
|
||||
ak, err := newAttestationKey()
|
||||
@ -96,3 +129,36 @@ func TestAttestationKeyUnmarshal(t *testing.T) {
|
||||
t.Error("unmarshalled public key is not the same as the original public key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttestationKeyClone(t *testing.T) {
|
||||
skipWithoutTPM(t)
|
||||
ak, err := newAttestationKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ak2 := ak.Clone()
|
||||
if ak2 == nil {
|
||||
t.Fatal("Clone failed")
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := ak2.Close(); err != nil {
|
||||
t.Errorf("ak2.Close: %v", err)
|
||||
}
|
||||
})
|
||||
// Close the original key, ak2 should remain open and usable.
|
||||
if err := ak.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
data := []byte("secrets")
|
||||
digest := sha256.Sum256(data)
|
||||
// Check signature/validation round trip using cloned key.
|
||||
sig, err := ak2.Sign(rand.Reader, digest[:], crypto.SHA256)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !ecdsa.VerifyASN1(ak2.Public().(*ecdsa.PublicKey), digest[:], sig) {
|
||||
t.Errorf("ecdsa.VerifyASN1 failed")
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user