diff --git a/internal/pkg/uki/internal/pe/extract.go b/internal/pkg/uki/internal/pe/extract.go index bbc275f9e..fa0f172fb 100644 --- a/internal/pkg/uki/internal/pe/extract.go +++ b/internal/pkg/uki/internal/pe/extract.go @@ -32,30 +32,22 @@ func Extract(ukiPath string) (assetInfo AssetInfo, err error) { assetInfo.fileCloser = peFile - for _, section := range peFile.Sections { - // read upto section.VirtualSize bytes - sectionReader := io.NewSectionReader(section, 0, int64(section.VirtualSize)) + sectionMap := map[string]*io.Reader{ + ".initrd": &assetInfo.Initrd, + ".cmdline": &assetInfo.Cmdline, + ".linux": &assetInfo.Kernel, + } - switch section.Name { - case ".initrd": - assetInfo.Initrd = sectionReader - case ".cmdline": - assetInfo.Cmdline = sectionReader - case ".linux": - assetInfo.Kernel = sectionReader + for _, section := range peFile.Sections { + if reader, exists := sectionMap[section.Name]; exists && *reader == nil { + *reader = io.NewSectionReader(section, 0, int64(section.VirtualSize)) } } - if assetInfo.Kernel == nil { - return assetInfo, fmt.Errorf("kernel not found in PE file") - } - - if assetInfo.Initrd == nil { - return assetInfo, fmt.Errorf("initrd not found in PE file") - } - - if assetInfo.Cmdline == nil { - return assetInfo, fmt.Errorf("cmdline not found in PE file") + for name, reader := range sectionMap { + if *reader == nil { + return assetInfo, fmt.Errorf("%s not found in PE file", name) + } } return assetInfo, nil diff --git a/internal/pkg/uki/internal/pe/extract_test.go b/internal/pkg/uki/internal/pe/extract_test.go index b5baa5832..2b60d6e94 100644 --- a/internal/pkg/uki/internal/pe/extract_test.go +++ b/internal/pkg/uki/internal/pe/extract_test.go @@ -23,7 +23,7 @@ func TestUKIExtract(t *testing.T) { destFile := filepath.Join(destDir, "vmlinuz.efi") - for _, section := range []string{"linux", "initrd", "cmdline"} { + for _, section := range []string{"linux", "initrd", "cmdline", "profile-default", "profile-reset", "cmdline-reset"} { assert.NoError(t, os.WriteFile(filepath.Join(destDir, section), []byte(section), 0o644)) } @@ -46,6 +46,24 @@ func TestUKIExtract(t *testing.T) { Measure: false, Append: true, }, + { + Name: ".profile", + Path: filepath.Join(destDir, "profile-default"), + Measure: false, + Append: true, + }, + { + Name: ".profile", + Path: filepath.Join(destDir, "profile-reset"), + Measure: false, + Append: true, + }, + { + Name: ".cmdline", + Path: filepath.Join(destDir, "cmdline-reset"), + Measure: false, + Append: true, + }, })) ukiData, err := pe.Extract(destFile) diff --git a/internal/pkg/uki/internal/pe/pe_test.go b/internal/pkg/uki/internal/pe/pe_test.go index ecbe186d9..d9a38d367 100644 --- a/internal/pkg/uki/internal/pe/pe_test.go +++ b/internal/pkg/uki/internal/pe/pe_test.go @@ -145,10 +145,10 @@ func TestMultipleSections(t *testing.T) { tmpDir := t.TempDir() unamePath := filepath.Join(tmpDir, "uname") - require.NoError(t, os.WriteFile(unamePath, []byte("Talos-helloworld"), 0o644)) + require.NoError(t, os.WriteFile(unamePath, []byte("Talos"), 0o644)) unameNewPath := filepath.Join(tmpDir, "uname-new") - require.NoError(t, os.WriteFile(unameNewPath, []byte("Talos-foobar"), 0o644)) + require.NoError(t, os.WriteFile(unameNewPath, []byte("Talos-new"), 0o644)) outNative := filepath.Join(tmpDir, "uki-native.bin") @@ -174,6 +174,6 @@ func TestMultipleSections(t *testing.T) { sectionContents := extractSection(t, outNative, ".uname") - assert.Contains(t, sectionContents, "Talos-helloworld") - assert.Contains(t, sectionContents, "Talos-foobar") + assert.Contains(t, sectionContents, "Talos") + assert.Contains(t, sectionContents, "Talos-new") }