diff --git a/internal/app/machined/pkg/controllers/block/volume_config_test.go b/internal/app/machined/pkg/controllers/block/volume_config_test.go index 797068101..5b0f963d5 100644 --- a/internal/app/machined/pkg/controllers/block/volume_config_test.go +++ b/internal/app/machined/pkg/controllers/block/volume_config_test.go @@ -29,6 +29,7 @@ import ( "github.com/siderolabs/talos/pkg/machinery/resources/block" "github.com/siderolabs/talos/pkg/machinery/resources/config" "github.com/siderolabs/talos/pkg/machinery/resources/runtime" + "github.com/siderolabs/talos/pkg/machinery/yamlutils" ) type VolumeConfigSuite struct { @@ -192,7 +193,7 @@ func (suite *VolumeConfigSuite) TestReconcileEncryptedSTATE() { asrt.Equal(1, r.TypedSpec().Encryption.Keys[0].Slot) asrt.Equal(block.EncryptionKeyStatic, r.TypedSpec().Encryption.Keys[0].Type) - asrt.Equal([]byte("supersecret"), r.TypedSpec().Encryption.Keys[0].StaticPassphrase) + asrt.Equal(yamlutils.StringBytes([]byte("supersecret")), r.TypedSpec().Encryption.Keys[0].StaticPassphrase) asrt.Equal(2, r.TypedSpec().Encryption.Keys[1].Slot) asrt.Equal(block.EncryptionKeyTPM, r.TypedSpec().Encryption.Keys[1].Type) diff --git a/pkg/machinery/resources/block/volume_config.go b/pkg/machinery/resources/block/volume_config.go index f73c3099f..5e8576982 100644 --- a/pkg/machinery/resources/block/volume_config.go +++ b/pkg/machinery/resources/block/volume_config.go @@ -14,6 +14,7 @@ import ( "github.com/siderolabs/talos/pkg/machinery/cel" "github.com/siderolabs/talos/pkg/machinery/proto" + "github.com/siderolabs/talos/pkg/machinery/yamlutils" ) // VolumeConfigType is type of VolumeConfig resource. @@ -138,7 +139,7 @@ type EncryptionKey struct { Type EncryptionKeyType `yaml:"type" protobuf:"2"` // Only for Type == "static": - StaticPassphrase []byte `yaml:"staticPassphrase,omitempty" protobuf:"3"` + StaticPassphrase yamlutils.StringBytes `yaml:"staticPassphrase,omitempty" protobuf:"3"` // Only for Type == "kms": KMSEndpoint string `yaml:"kmsEndpoint,omitempty" protobuf:"4"` diff --git a/pkg/machinery/resources/files/etcfile_spec.go b/pkg/machinery/resources/files/etcfile_spec.go index e509141b3..a130666e9 100644 --- a/pkg/machinery/resources/files/etcfile_spec.go +++ b/pkg/machinery/resources/files/etcfile_spec.go @@ -13,6 +13,7 @@ import ( "github.com/cosi-project/runtime/pkg/resource/typed" "github.com/siderolabs/talos/pkg/machinery/proto" + "github.com/siderolabs/talos/pkg/machinery/yamlutils" ) //go:generate deep-copy -type EtcFileSpecSpec -type EtcFileStatusSpec -header-file ../../../../hack/boilerplate.txt -o deep_copy.generated.go . @@ -27,9 +28,9 @@ type EtcFileSpec = typed.Resource[EtcFileSpecSpec, EtcFileSpecExtension] // //gotagsrewrite:gen type EtcFileSpecSpec struct { - Contents []byte `yaml:"contents" protobuf:"1"` - Mode fs.FileMode `yaml:"mode" protobuf:"2"` - SelinuxLabel string `yaml:"selinux_label" protobuf:"3"` + Contents yamlutils.StringBytes `yaml:"contents" protobuf:"1"` + Mode fs.FileMode `yaml:"mode" protobuf:"2"` + SelinuxLabel string `yaml:"selinux_label" protobuf:"3"` } // NewEtcFileSpec initializes a EtcFileSpec resource. diff --git a/pkg/machinery/yamlutils/yamlutils.go b/pkg/machinery/yamlutils/yamlutils.go new file mode 100644 index 000000000..33d199370 --- /dev/null +++ b/pkg/machinery/yamlutils/yamlutils.go @@ -0,0 +1,47 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package yamlutils provides utility types to work with YAML marshaling and unmarshaling. +package yamlutils + +import "bytes" + +// StringBytes is a type that represents a byte slice as a string when marshaled to YAML. +type StringBytes []byte + +// MarshalYAML implements yaml.Marshaller interface for StringBytes. +func (s StringBytes) MarshalYAML() (any, error) { + if bytes.Equal(bytes.ToValidUTF8(s, nil), s) { + // If the byte slice is valid UTF-8, return it as a string. + return string(s), nil + } + + return s.Bytes(), nil +} + +// UnmarshalYAML implements yaml.Unmarshaler interface for StringBytes. +func (s *StringBytes) UnmarshalYAML(unmarshal func(any) error) error { + var str string + + if err := unmarshal(&str); err == nil { + *s = []byte(str) + + return nil + } + + var data []byte + + if err := unmarshal(&data); err != nil { + return err + } + + *s = data + + return nil +} + +// Bytes returns the byte slice representation of StringBytes. +func (s StringBytes) Bytes() []byte { + return []byte(s) +} diff --git a/pkg/machinery/yamlutils/yamlutils_test.go b/pkg/machinery/yamlutils/yamlutils_test.go new file mode 100644 index 000000000..fecaaa38c --- /dev/null +++ b/pkg/machinery/yamlutils/yamlutils_test.go @@ -0,0 +1,96 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package yamlutils_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" + + "github.com/siderolabs/talos/pkg/machinery/yamlutils" +) + +func TestStringBytes(t *testing.T) { + t.Parallel() + + type sbStruct struct { + Field yamlutils.StringBytes `yaml:"field"` + } + + for _, test := range []struct { + name string + + in any + expected string + + empty func() any + + // extraMarshaled is a list of strings that should be unmarshaled from YAML into the same `in` + extraMarshaled []string + }{ + { + name: "simple", + in: &sbStruct{yamlutils.StringBytes([]byte("abcde"))}, + + expected: "field: abcde\n", + empty: func() any { + return &sbStruct{} + }, + extraMarshaled: []string{ + "field:\n - 0x61\n - 0x62\n - 0x63\n - 0x64\n - 0x65\n", + "field:\n - 97\n - 98\n - 99\n - 100\n - 101\n", + }, + }, + { + name: "empty", + in: &sbStruct{yamlutils.StringBytes([]byte{})}, + + expected: "field: \"\"\n", + empty: func() any { + return &sbStruct{} + }, + }, + { + name: "invalid utf8", + in: &sbStruct{yamlutils.StringBytes([]byte{0xff})}, + + expected: "field:\n - 255\n", + empty: func() any { + return &sbStruct{} + }, + + extraMarshaled: []string{ + "field:\n - 0xff\n", + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + out, err := yaml.Marshal(test.in) + require.NoError(t, err) + + assert.Equal(t, test.expected, string(out)) + + back := test.empty() + + err = yaml.Unmarshal(out, back) + require.NoError(t, err) + + assert.Equal(t, test.in, back) + + for _, extra := range test.extraMarshaled { + back := test.empty() + + err = yaml.Unmarshal([]byte(extra), back) + require.NoError(t, err) + + assert.Equal(t, test.in, back) + } + }) + } +}