From eb2b905af0a509ebe7f174694643b62be6eb4d0a Mon Sep 17 00:00:00 2001 From: miagilepner Date: Thu, 1 Feb 2024 13:40:15 +0100 Subject: [PATCH] Support adding new stubs to existing stub files (#25130) * stubmaker can generate stubs for only the missing functions * check error --- tools/stubmaker/main.go | 181 ++++++++++++++++++++++++---------------- 1 file changed, 110 insertions(+), 71 deletions(-) diff --git a/tools/stubmaker/main.go b/tools/stubmaker/main.go index 53676e08b8..187ca219e2 100644 --- a/tools/stubmaker/main.go +++ b/tools/stubmaker/main.go @@ -8,20 +8,22 @@ import ( "bytes" "errors" "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" "go/types" - "io" "os" "path/filepath" - "regexp" - "sort" "strings" "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing" "github.com/go-git/go-git/v5/plumbing/object" - "github.com/google/go-cmp/cmp" "github.com/hashicorp/go-hclog" + "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/packages" + "golang.org/x/tools/imports" ) var logger hclog.Logger @@ -31,6 +33,11 @@ func fatal(err error) { os.Exit(1) } +type generator struct { + file *ast.File + fset *token.FileSet +} + func main() { logger = hclog.New(&hclog.LoggerOptions{ Name: "stubmaker", @@ -67,14 +74,15 @@ func main() { fatal(err) } - inputLines, err := readLines(bytes.NewBuffer(b)) + inputParsed, err := parseFile(b) if err != nil { fatal(err) } - funcs := getFuncs(inputLines) - if needed, err := isStubNeeded(funcs); err != nil { + needed, existing, err := inputParsed.areStubsNeeded() + if err != nil { fatal(err) - } else if !needed { + } + if !needed { return } @@ -107,7 +115,7 @@ func main() { if err != nil { fatal(err) } - _, err = io.WriteString(output, strings.Join(getOutput(inputLines), "\n")+"\n") + err = inputParsed.writeStubs(output, existing) if err != nil { // If we don't end up writing to the file, delete it. os.Remove(outputFile + ".tmp") @@ -119,6 +127,57 @@ func main() { } } +func (g *generator) writeStubs(output *os.File, existingFuncs map[string]struct{}) error { + // delete all functions/methods that are already defined + g.modifyAST(existingFuncs) + + // write the updated code to buf + buf := new(bytes.Buffer) + err := format.Node(buf, g.fset, g.file) + if err != nil { + return err + } + + // remove any unneeded imports + res, err := imports.Process("", buf.Bytes(), &imports.Options{ + Fragment: true, + AllErrors: false, + Comments: true, + FormatOnly: false, + }) + if err != nil { + return err + } + + // add the code generation line and update the build tags + outputLines, err := fixGeneratedComments(res) + if err != nil { + return err + } + _, err = output.WriteString(strings.Join(outputLines, "\n") + "\n") + return err +} + +func fixGeneratedComments(b []byte) ([]string, error) { + warning := "// Code generated by tools/stubmaker; DO NOT EDIT." + goGenerate := "//go:generate go run github.com/hashicorp/vault/tools/stubmaker" + + scanner := bufio.NewScanner(bytes.NewBuffer(b)) + var outputLines []string + for scanner.Scan() { + line := scanner.Text() + switch { + case strings.Contains(line, "//go:build ") && strings.Contains(line, "!enterprise"): + outputLines = append(outputLines, warning, "") + line = strings.ReplaceAll(line, "!enterprise", "enterprise") + case line == goGenerate: + continue + } + outputLines = append(outputLines, line) + } + return outputLines, scanner.Err() +} + func inGit(wt *git.Worktree, st git.Status, obj object.Object, path string) (bool, error) { absPath, err := filepath.Abs(path) if err != nil { @@ -189,27 +248,24 @@ func resolve(obj object.Object, path string) (*object.Blob, error) { } } -func readLines(r io.Reader) ([]string, error) { - scanner := bufio.NewScanner(r) - scanner.Split(bufio.ScanLines) - - var lines []string - for scanner.Scan() { - lines = append(lines, scanner.Text()) - } - if err := scanner.Err(); err != nil { - return nil, err - } - return lines, nil -} - -func isStubNeeded(funcs []string) (bool, error) { +// areStubsNeeded checks if all functions and methods defined in the stub file +// are present in the package +func (g *generator) areStubsNeeded() (needed bool, existingStubs map[string]struct{}, err error) { pkg, err := parsePackage(".", []string{"enterprise"}) if err != nil { - return false, err + return false, nil, err } - var found []string + stubFunctions := make(map[string]struct{}) + for _, d := range g.file.Decls { + dFunc, ok := d.(*ast.FuncDecl) + if !ok { + continue + } + stubFunctions[dFunc.Name.Name] = struct{}{} + + } + found := make(map[string]struct{}) for name, val := range pkg.TypesInfo.Defs { if val == nil { continue @@ -218,54 +274,25 @@ func isStubNeeded(funcs []string) (bool, error) { if !ok { continue } - for _, f := range funcs { - if name.Name == f { - found = append(found, f) + if _, ok := stubFunctions[name.Name]; ok { + found[name.Name] = struct{}{} + } + } + + return len(found) != len(stubFunctions), found, nil +} + +func (g *generator) modifyAST(exists map[string]struct{}) { + astutil.Apply(g.file, nil, func(c *astutil.Cursor) bool { + switch x := c.Node().(type) { + case *ast.FuncDecl: + if _, ok := exists[x.Name.Name]; ok { + c.Delete() } } - } - switch { - case len(found) == len(funcs): - return false, nil - case len(found) != 0: - sort.Strings(found) - sort.Strings(funcs) - delta := cmp.Diff(found, funcs) - return false, fmt.Errorf("funcs partially defined, delta=%s", delta) - } - return true, nil -} - -var funcRE = regexp.MustCompile("^func *(?:[(][^)]+[)])? *([^(]+)") - -func getFuncs(inputLines []string) []string { - var funcs []string - for _, line := range inputLines { - matches := funcRE.FindStringSubmatch(line) - if len(matches) > 1 { - funcs = append(funcs, matches[1]) - } - } - return funcs -} - -func getOutput(inputLines []string) []string { - warning := "// Code generated by tools/stubmaker; DO NOT EDIT." - - var outputLines []string - for _, line := range inputLines { - switch line { - case "//go:build !enterprise": - outputLines = append(outputLines, warning, "") - line = "//go:build enterprise" - case "//go:generate go run github.com/hashicorp/vault/tools/stubmaker": - continue - } - outputLines = append(outputLines, line) - } - - return outputLines + return true + }) } func parsePackage(name string, tags []string) (*packages.Package, error) { @@ -283,3 +310,15 @@ func parsePackage(name string, tags []string) (*packages.Package, error) { } return pkgs[0], nil } + +func parseFile(buffer []byte) (*generator, error) { + fs := token.NewFileSet() + f, err := parser.ParseFile(fs, "", buffer, parser.AllErrors|parser.ParseComments) + if err != nil { + return nil, err + } + return &generator{ + file: f, + fset: fs, + }, nil +}