diff --git a/command/transit_import_key.go b/command/transit_import_key.go index 8d9076bdc7..d8dd60b074 100644 --- a/command/transit_import_key.go +++ b/command/transit_import_key.go @@ -134,7 +134,7 @@ func ImportKey(c *BaseCommand, operation string, pathFunc ImportKeyFunc, flags * } // Fetch the wrapping key c.UI.Output("Retrieving wrapping key.") - wrappingKey, err := fetchWrappingKey(c, client, path) + wrappingKey, err := fetchWrappingKey(client, path) if err != nil { c.UI.Error(fmt.Sprintf("failed to fetch wrapping key: %v", err)) return 3 @@ -154,7 +154,7 @@ func ImportKey(c *BaseCommand, operation string, pathFunc ImportKeyFunc, flags * wrappedAESKey, err := rsa.EncryptOAEP( sha256.New(), rand.Reader, - wrappingKey.(*rsa.PublicKey), + wrappingKey, ephemeralAESKey, []byte{}, ) @@ -190,7 +190,7 @@ func ImportKey(c *BaseCommand, operation string, pathFunc ImportKeyFunc, flags * } } -func fetchWrappingKey(c *BaseCommand, client *api.Client, path string) (any, error) { +func fetchWrappingKey(client *api.Client, path string) (*rsa.PublicKey, error) { resp, err := client.Logical().Read(path + "/wrapping_key") if err != nil { return nil, fmt.Errorf("error fetching wrapping key: %w", err) @@ -200,12 +200,19 @@ func fetchWrappingKey(c *BaseCommand, client *api.Client, path string) (any, err } key, ok := resp.Data["public_key"] if !ok { - c.UI.Error("could not find wrapping key") + return nil, fmt.Errorf("missing public_key field in response") } keyBlock, _ := pem.Decode([]byte(key.(string))) + if keyBlock == nil { + return nil, fmt.Errorf("failed to decode PEM information from public_key response field") + } parsedKey, err := x509.ParsePKIXPublicKey(keyBlock.Bytes) if err != nil { return nil, fmt.Errorf("error parsing wrapping key: %w", err) } - return parsedKey, nil + rsaKey, ok := parsedKey.(*rsa.PublicKey) + if !ok { + return nil, fmt.Errorf("returned value was not an RSA public key but a %T", rsaKey) + } + return rsaKey, nil }