diff --git a/feature/tpm/attestation.go b/feature/tpm/attestation.go index 5fbda3b17..597d4a649 100644 --- a/feature/tpm/attestation.go +++ b/feature/tpm/attestation.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "log" + "sync" "github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2/transport" @@ -19,7 +20,8 @@ import ( ) type attestationKey struct { - tpm transport.TPMCloser + tpmMu sync.Mutex + tpm transport.TPMCloser // private and public parts of the TPM key as returned from tpm2.Create. // These are used for serialization. tpmPrivate tpm2.TPM2BPrivate @@ -144,7 +146,7 @@ type attestationKeySerialized struct { // MarshalJSON implements json.Marshaler. func (ak *attestationKey) MarshalJSON() ([]byte, error) { - if ak == nil || ak.IsZero() { + if ak == nil || len(ak.tpmPublic.Bytes()) == 0 || len(ak.tpmPrivate.Buffer) == 0 { return []byte("null"), nil } return json.Marshal(attestationKeySerialized{ @@ -163,6 +165,13 @@ func (ak *attestationKey) UnmarshalJSON(data []byte) (retErr error) { ak.tpmPrivate = tpm2.TPM2BPrivate{Buffer: aks.TPMPrivate} ak.tpmPublic = tpm2.BytesAs2B[tpm2.TPMTPublic, *tpm2.TPMTPublic](aks.TPMPublic) + ak.tpmMu.Lock() + defer ak.tpmMu.Unlock() + if ak.tpm != nil { + ak.tpm.Close() + ak.tpm = nil + } + tpm, err := open() if err != nil { return key.ErrUnsupported @@ -182,6 +191,9 @@ func (ak *attestationKey) Public() crypto.PublicKey { } func (ak *attestationKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { + ak.tpmMu.Lock() + defer ak.tpmMu.Unlock() + if !ak.loaded() { return nil, errors.New("tpm2 attestation key is not loaded during Sign") } @@ -247,6 +259,9 @@ func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) { } func (ak *attestationKey) Close() error { + ak.tpmMu.Lock() + defer ak.tpmMu.Unlock() + var errs []error if ak.handle != nil && ak.tpm != nil { _, err := tpm2.FlushContext{FlushHandle: ak.handle.Handle}.Execute(ak.tpm) @@ -262,18 +277,31 @@ func (ak *attestationKey) Clone() key.HardwareAttestationKey { if ak == nil { return nil } - return &attestationKey{ - tpm: ak.tpm, + + tpm, err := open() + if err != nil { + log.Printf("[unexpected] failed to open a TPM connection in feature/tpm.attestationKey.Clone: %v", err) + return nil + } + akc := &attestationKey{ + tpm: tpm, tpmPrivate: ak.tpmPrivate, tpmPublic: ak.tpmPublic, - handle: ak.handle, - pub: ak.pub, } + if err := akc.load(); err != nil { + log.Printf("[unexpected] failed to load TPM key in feature/tpm.attestationKey.Clone: %v", err) + tpm.Close() + return nil + } + return akc } func (ak *attestationKey) IsZero() bool { if ak == nil { return true } + + ak.tpmMu.Lock() + defer ak.tpmMu.Unlock() return !ak.loaded() } diff --git a/feature/tpm/attestation_test.go b/feature/tpm/attestation_test.go index ead88c955..e7ff72987 100644 --- a/feature/tpm/attestation_test.go +++ b/feature/tpm/attestation_test.go @@ -10,6 +10,8 @@ import ( "crypto/rand" "crypto/sha256" "encoding/json" + "runtime" + "sync" "testing" ) @@ -62,6 +64,37 @@ func TestAttestationKeySign(t *testing.T) { } } +func TestAttestationKeySignConcurrent(t *testing.T) { + skipWithoutTPM(t) + ak, err := newAttestationKey() + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := ak.Close(); err != nil { + t.Errorf("ak.Close: %v", err) + } + }) + + data := []byte("secrets") + digest := sha256.Sum256(data) + + wg := sync.WaitGroup{} + for range runtime.GOMAXPROCS(-1) { + wg.Go(func() { + // Check signature/validation round trip. + sig, err := ak.Sign(rand.Reader, digest[:], crypto.SHA256) + if err != nil { + t.Fatal(err) + } + if !ecdsa.VerifyASN1(ak.Public().(*ecdsa.PublicKey), digest[:], sig) { + t.Errorf("ecdsa.VerifyASN1 failed") + } + }) + } + wg.Wait() +} + func TestAttestationKeyUnmarshal(t *testing.T) { skipWithoutTPM(t) ak, err := newAttestationKey() @@ -96,3 +129,36 @@ func TestAttestationKeyUnmarshal(t *testing.T) { t.Error("unmarshalled public key is not the same as the original public key") } } + +func TestAttestationKeyClone(t *testing.T) { + skipWithoutTPM(t) + ak, err := newAttestationKey() + if err != nil { + t.Fatal(err) + } + + ak2 := ak.Clone() + if ak2 == nil { + t.Fatal("Clone failed") + } + t.Cleanup(func() { + if err := ak2.Close(); err != nil { + t.Errorf("ak2.Close: %v", err) + } + }) + // Close the original key, ak2 should remain open and usable. + if err := ak.Close(); err != nil { + t.Fatal(err) + } + + data := []byte("secrets") + digest := sha256.Sum256(data) + // Check signature/validation round trip using cloned key. + sig, err := ak2.Sign(rand.Reader, digest[:], crypto.SHA256) + if err != nil { + t.Fatal(err) + } + if !ecdsa.VerifyASN1(ak2.Public().(*ecdsa.PublicKey), digest[:], sig) { + t.Errorf("ecdsa.VerifyASN1 failed") + } +}