mirror of
https://github.com/hashicorp/vault.git
synced 2026-05-04 20:06:27 +02:00
Support adding new stubs to existing stub files (#25130)
* stubmaker can generate stubs for only the missing functions * check error
This commit is contained in:
parent
f0e7f114a1
commit
eb2b905af0
@ -8,20 +8,22 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/format"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"go/types"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/go-git/go-git/v5"
|
||||
"github.com/go-git/go-git/v5/plumbing"
|
||||
"github.com/go-git/go-git/v5/plumbing/object"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"golang.org/x/tools/go/ast/astutil"
|
||||
"golang.org/x/tools/go/packages"
|
||||
"golang.org/x/tools/imports"
|
||||
)
|
||||
|
||||
var logger hclog.Logger
|
||||
@ -31,6 +33,11 @@ func fatal(err error) {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
type generator struct {
|
||||
file *ast.File
|
||||
fset *token.FileSet
|
||||
}
|
||||
|
||||
func main() {
|
||||
logger = hclog.New(&hclog.LoggerOptions{
|
||||
Name: "stubmaker",
|
||||
@ -67,14 +74,15 @@ func main() {
|
||||
fatal(err)
|
||||
}
|
||||
|
||||
inputLines, err := readLines(bytes.NewBuffer(b))
|
||||
inputParsed, err := parseFile(b)
|
||||
if err != nil {
|
||||
fatal(err)
|
||||
}
|
||||
funcs := getFuncs(inputLines)
|
||||
if needed, err := isStubNeeded(funcs); err != nil {
|
||||
needed, existing, err := inputParsed.areStubsNeeded()
|
||||
if err != nil {
|
||||
fatal(err)
|
||||
} else if !needed {
|
||||
}
|
||||
if !needed {
|
||||
return
|
||||
}
|
||||
|
||||
@ -107,7 +115,7 @@ func main() {
|
||||
if err != nil {
|
||||
fatal(err)
|
||||
}
|
||||
_, err = io.WriteString(output, strings.Join(getOutput(inputLines), "\n")+"\n")
|
||||
err = inputParsed.writeStubs(output, existing)
|
||||
if err != nil {
|
||||
// If we don't end up writing to the file, delete it.
|
||||
os.Remove(outputFile + ".tmp")
|
||||
@ -119,6 +127,57 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func (g *generator) writeStubs(output *os.File, existingFuncs map[string]struct{}) error {
|
||||
// delete all functions/methods that are already defined
|
||||
g.modifyAST(existingFuncs)
|
||||
|
||||
// write the updated code to buf
|
||||
buf := new(bytes.Buffer)
|
||||
err := format.Node(buf, g.fset, g.file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// remove any unneeded imports
|
||||
res, err := imports.Process("", buf.Bytes(), &imports.Options{
|
||||
Fragment: true,
|
||||
AllErrors: false,
|
||||
Comments: true,
|
||||
FormatOnly: false,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// add the code generation line and update the build tags
|
||||
outputLines, err := fixGeneratedComments(res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = output.WriteString(strings.Join(outputLines, "\n") + "\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func fixGeneratedComments(b []byte) ([]string, error) {
|
||||
warning := "// Code generated by tools/stubmaker; DO NOT EDIT."
|
||||
goGenerate := "//go:generate go run github.com/hashicorp/vault/tools/stubmaker"
|
||||
|
||||
scanner := bufio.NewScanner(bytes.NewBuffer(b))
|
||||
var outputLines []string
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
switch {
|
||||
case strings.Contains(line, "//go:build ") && strings.Contains(line, "!enterprise"):
|
||||
outputLines = append(outputLines, warning, "")
|
||||
line = strings.ReplaceAll(line, "!enterprise", "enterprise")
|
||||
case line == goGenerate:
|
||||
continue
|
||||
}
|
||||
outputLines = append(outputLines, line)
|
||||
}
|
||||
return outputLines, scanner.Err()
|
||||
}
|
||||
|
||||
func inGit(wt *git.Worktree, st git.Status, obj object.Object, path string) (bool, error) {
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
@ -189,27 +248,24 @@ func resolve(obj object.Object, path string) (*object.Blob, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func readLines(r io.Reader) ([]string, error) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
var lines []string
|
||||
for scanner.Scan() {
|
||||
lines = append(lines, scanner.Text())
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return lines, nil
|
||||
}
|
||||
|
||||
func isStubNeeded(funcs []string) (bool, error) {
|
||||
// areStubsNeeded checks if all functions and methods defined in the stub file
|
||||
// are present in the package
|
||||
func (g *generator) areStubsNeeded() (needed bool, existingStubs map[string]struct{}, err error) {
|
||||
pkg, err := parsePackage(".", []string{"enterprise"})
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
var found []string
|
||||
stubFunctions := make(map[string]struct{})
|
||||
for _, d := range g.file.Decls {
|
||||
dFunc, ok := d.(*ast.FuncDecl)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
stubFunctions[dFunc.Name.Name] = struct{}{}
|
||||
|
||||
}
|
||||
found := make(map[string]struct{})
|
||||
for name, val := range pkg.TypesInfo.Defs {
|
||||
if val == nil {
|
||||
continue
|
||||
@ -218,54 +274,25 @@ func isStubNeeded(funcs []string) (bool, error) {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, f := range funcs {
|
||||
if name.Name == f {
|
||||
found = append(found, f)
|
||||
if _, ok := stubFunctions[name.Name]; ok {
|
||||
found[name.Name] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
return len(found) != len(stubFunctions), found, nil
|
||||
}
|
||||
|
||||
func (g *generator) modifyAST(exists map[string]struct{}) {
|
||||
astutil.Apply(g.file, nil, func(c *astutil.Cursor) bool {
|
||||
switch x := c.Node().(type) {
|
||||
case *ast.FuncDecl:
|
||||
if _, ok := exists[x.Name.Name]; ok {
|
||||
c.Delete()
|
||||
}
|
||||
}
|
||||
}
|
||||
switch {
|
||||
case len(found) == len(funcs):
|
||||
return false, nil
|
||||
case len(found) != 0:
|
||||
sort.Strings(found)
|
||||
sort.Strings(funcs)
|
||||
delta := cmp.Diff(found, funcs)
|
||||
return false, fmt.Errorf("funcs partially defined, delta=%s", delta)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var funcRE = regexp.MustCompile("^func *(?:[(][^)]+[)])? *([^(]+)")
|
||||
|
||||
func getFuncs(inputLines []string) []string {
|
||||
var funcs []string
|
||||
for _, line := range inputLines {
|
||||
matches := funcRE.FindStringSubmatch(line)
|
||||
if len(matches) > 1 {
|
||||
funcs = append(funcs, matches[1])
|
||||
}
|
||||
}
|
||||
return funcs
|
||||
}
|
||||
|
||||
func getOutput(inputLines []string) []string {
|
||||
warning := "// Code generated by tools/stubmaker; DO NOT EDIT."
|
||||
|
||||
var outputLines []string
|
||||
for _, line := range inputLines {
|
||||
switch line {
|
||||
case "//go:build !enterprise":
|
||||
outputLines = append(outputLines, warning, "")
|
||||
line = "//go:build enterprise"
|
||||
case "//go:generate go run github.com/hashicorp/vault/tools/stubmaker":
|
||||
continue
|
||||
}
|
||||
outputLines = append(outputLines, line)
|
||||
}
|
||||
|
||||
return outputLines
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func parsePackage(name string, tags []string) (*packages.Package, error) {
|
||||
@ -283,3 +310,15 @@ func parsePackage(name string, tags []string) (*packages.Package, error) {
|
||||
}
|
||||
return pkgs[0], nil
|
||||
}
|
||||
|
||||
func parseFile(buffer []byte) (*generator, error) {
|
||||
fs := token.NewFileSet()
|
||||
f, err := parser.ParseFile(fs, "", buffer, parser.AllErrors|parser.ParseComments)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &generator{
|
||||
file: f,
|
||||
fset: fs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user