mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-22 07:01:09 +02:00
333 lines
7.5 KiB
Go
333 lines
7.5 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/format"
|
|
"go/parser"
|
|
"go/token"
|
|
"go/types"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/go-git/go-git/v5"
|
|
"github.com/go-git/go-git/v5/plumbing"
|
|
"github.com/go-git/go-git/v5/plumbing/object"
|
|
"github.com/hashicorp/go-hclog"
|
|
"golang.org/x/tools/go/ast/astutil"
|
|
"golang.org/x/tools/go/packages"
|
|
"golang.org/x/tools/imports"
|
|
)
|
|
|
|
var logger hclog.Logger
|
|
|
|
func fatal(err error) {
|
|
logger.Error("fatal error", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
// generator holds a parsed Go source file together with the token.FileSet it
// was parsed with, so the AST can be inspected, modified, and printed again.
type generator struct {
	// file is the syntax tree of the parsed stub file.
	file *ast.File
	// fset is the file set that was passed to the parser when file was built.
	fset *token.FileSet
}
|
|
|
|
func main() {
|
|
logger = hclog.New(&hclog.LoggerOptions{
|
|
Name: "stubmaker",
|
|
Level: hclog.Trace,
|
|
})
|
|
|
|
// Setup git, both so we can determine if we're running on enterprise, and
|
|
// so we can make sure we don't clobber a non-transient file.
|
|
repo, err := git.PlainOpenWithOptions(".", &git.PlainOpenOptions{
|
|
DetectDotGit: true,
|
|
})
|
|
if err != nil {
|
|
if err.Error() != "repository does not exist" {
|
|
fatal(err)
|
|
}
|
|
repo = nil
|
|
}
|
|
|
|
var wt *git.Worktree
|
|
if repo != nil {
|
|
wt, err = repo.Worktree()
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
if !isEnterprise(wt) {
|
|
return
|
|
}
|
|
}
|
|
|
|
// Read the file and figure out if we need to do anything.
|
|
inputFile := os.Getenv("GOFILE")
|
|
if !strings.HasSuffix(inputFile, "_stubs_oss.go") {
|
|
fatal(fmt.Errorf("stubmaker should only be invoked from files ending in _stubs_oss.go"))
|
|
}
|
|
|
|
baseFilename := strings.TrimSuffix(inputFile, "_stubs_oss.go")
|
|
outputFile := baseFilename + "_stubs_ent.go"
|
|
b, err := os.ReadFile(inputFile)
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
|
|
inputParsed, err := parseFile(b)
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
needed, existing, err := inputParsed.areStubsNeeded()
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
if !needed {
|
|
return
|
|
}
|
|
|
|
// We'd like to write the file, but first make sure that we're not going
|
|
// to blow away anyone's work or overwrite a file already in git.
|
|
if repo != nil {
|
|
head, err := repo.Head()
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
obj, err := repo.Object(plumbing.AnyObject, head.Hash())
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
|
|
st, err := wt.Status()
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
|
|
tracked, err := inGit(wt, st, obj, outputFile)
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
if tracked {
|
|
fatal(fmt.Errorf("output file %s exists in git, not overwriting", outputFile))
|
|
}
|
|
}
|
|
|
|
// Now we can finally write the file
|
|
output, err := os.Create(outputFile + ".tmp")
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
err = inputParsed.writeStubs(output, existing)
|
|
if err != nil {
|
|
// If we don't end up writing to the file, delete it.
|
|
os.Remove(outputFile + ".tmp")
|
|
} else {
|
|
os.Rename(outputFile+".tmp", outputFile)
|
|
}
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
}
|
|
|
|
func (g *generator) writeStubs(output *os.File, existingFuncs map[string]struct{}) error {
|
|
// delete all functions/methods that are already defined
|
|
g.modifyAST(existingFuncs)
|
|
|
|
// write the updated code to buf
|
|
buf := new(bytes.Buffer)
|
|
err := format.Node(buf, g.fset, g.file)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// remove any unneeded imports
|
|
res, err := imports.Process("", buf.Bytes(), &imports.Options{
|
|
Fragment: true,
|
|
AllErrors: false,
|
|
Comments: true,
|
|
FormatOnly: false,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// add the code generation line and update the build tags
|
|
outputLines, err := fixGeneratedComments(res)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = output.WriteString(strings.Join(outputLines, "\n") + "\n")
|
|
return err
|
|
}
|
|
|
|
// fixGeneratedComments rewrites the header comments of a stub file so that it
// is suitable as generated enterprise output, and returns the result split
// into lines (without line terminators).
//
// For every line that contains "//go:build " and "!enterprise", a "Code
// generated ... DO NOT EDIT." warning and a blank line are emitted first, and
// every "!enterprise" on that line is replaced by "enterprise". Lines that are
// exactly the stubmaker go:generate directive are dropped. All other lines are
// passed through unchanged.
//
// The returned error is the scanner's error, if any.
func fixGeneratedComments(b []byte) ([]string, error) {
	const (
		warning    = "// Code generated by tools/stubmaker; DO NOT EDIT."
		goGenerate = "//go:generate go run github.com/hashicorp/vault/tools/stubmaker"
	)

	scanner := bufio.NewScanner(bytes.NewReader(b))
	// By default a Scanner refuses lines longer than bufio.MaxScanTokenSize
	// (64 KiB) with bufio.ErrTooLong. The whole input is already in memory,
	// so allow a single line to be as long as the input itself.
	scanner.Buffer(make([]byte, 0, 4096), len(b)+1)

	var outputLines []string
	for scanner.Scan() {
		line := scanner.Text()
		switch {
		case strings.Contains(line, "//go:build ") && strings.Contains(line, "!enterprise"):
			outputLines = append(outputLines, warning, "")
			line = strings.ReplaceAll(line, "!enterprise", "enterprise")
		case line == goGenerate:
			// The generated file must not re-trigger the generator.
			continue
		}
		outputLines = append(outputLines, line)
	}
	return outputLines, scanner.Err()
}
|
|
|
|
func inGit(wt *git.Worktree, st git.Status, obj object.Object, path string) (bool, error) {
|
|
absPath, err := filepath.Abs(path)
|
|
if err != nil {
|
|
return false, fmt.Errorf("path %s can't be made absolute: %w", path, err)
|
|
}
|
|
relPath, err := filepath.Rel(wt.Filesystem.Root(), absPath)
|
|
if err != nil {
|
|
return false, fmt.Errorf("path %s can't be made relative: %w", absPath, err)
|
|
}
|
|
|
|
fst := st.File(relPath)
|
|
if fst.Worktree != git.Untracked || fst.Staging != git.Untracked {
|
|
return true, nil
|
|
}
|
|
|
|
curwd, err := os.Getwd()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
blob, err := resolve(obj, relPath)
|
|
if err != nil && !strings.Contains(err.Error(), "file not found") {
|
|
return false, fmt.Errorf("error resolving path %s from %s: %w", relPath, curwd, err)
|
|
}
|
|
|
|
return blob != nil, nil
|
|
}
|
|
|
|
func isEnterprise(wt *git.Worktree) bool {
|
|
st, err := wt.Filesystem.Stat("enthelpers")
|
|
onOss := errors.Is(err, os.ErrNotExist)
|
|
onEnt := st != nil
|
|
|
|
switch {
|
|
case onOss && !onEnt:
|
|
case !onOss && onEnt:
|
|
default:
|
|
fatal(err)
|
|
}
|
|
return onEnt
|
|
}
|
|
|
|
// resolve blob at given path from obj. obj can be a commit, tag, tree, or blob.
|
|
func resolve(obj object.Object, path string) (*object.Blob, error) {
|
|
switch o := obj.(type) {
|
|
case *object.Commit:
|
|
t, err := o.Tree()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return resolve(t, path)
|
|
case *object.Tag:
|
|
target, err := o.Object()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return resolve(target, path)
|
|
case *object.Tree:
|
|
file, err := o.File(path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &file.Blob, nil
|
|
case *object.Blob:
|
|
return o, nil
|
|
default:
|
|
return nil, object.ErrUnsupportedObject
|
|
}
|
|
}
|
|
|
|
// areStubsNeeded checks if all functions and methods defined in the stub file
|
|
// are present in the package
|
|
func (g *generator) areStubsNeeded() (needed bool, existingStubs map[string]struct{}, err error) {
|
|
pkg, err := parsePackage(".", []string{"enterprise"})
|
|
if err != nil {
|
|
return false, nil, err
|
|
}
|
|
|
|
stubFunctions := make(map[string]struct{})
|
|
for _, d := range g.file.Decls {
|
|
dFunc, ok := d.(*ast.FuncDecl)
|
|
if !ok {
|
|
continue
|
|
}
|
|
stubFunctions[dFunc.Name.Name] = struct{}{}
|
|
|
|
}
|
|
found := make(map[string]struct{})
|
|
for name, val := range pkg.TypesInfo.Defs {
|
|
if val == nil {
|
|
continue
|
|
}
|
|
_, ok := val.Type().(*types.Signature)
|
|
if !ok {
|
|
continue
|
|
}
|
|
if _, ok := stubFunctions[name.Name]; ok {
|
|
found[name.Name] = struct{}{}
|
|
}
|
|
}
|
|
|
|
return len(found) != len(stubFunctions), found, nil
|
|
}
|
|
|
|
func (g *generator) modifyAST(exists map[string]struct{}) {
|
|
astutil.Apply(g.file, nil, func(c *astutil.Cursor) bool {
|
|
switch x := c.Node().(type) {
|
|
case *ast.FuncDecl:
|
|
if _, ok := exists[x.Name.Name]; ok {
|
|
c.Delete()
|
|
}
|
|
}
|
|
|
|
return true
|
|
})
|
|
}
|
|
|
|
func parsePackage(name string, tags []string) (*packages.Package, error) {
|
|
cfg := &packages.Config{
|
|
Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax,
|
|
Tests: false,
|
|
BuildFlags: []string{fmt.Sprintf("-tags=%s", strings.Join(tags, " "))},
|
|
}
|
|
pkgs, err := packages.Load(cfg, name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error parsing package %s: %v", name, err)
|
|
}
|
|
if len(pkgs) != 1 {
|
|
return nil, fmt.Errorf("error: %d packages found", len(pkgs))
|
|
}
|
|
return pkgs[0], nil
|
|
}
|
|
|
|
func parseFile(buffer []byte) (*generator, error) {
|
|
fs := token.NewFileSet()
|
|
f, err := parser.ParseFile(fs, "", buffer, parser.AllErrors|parser.ParseComments)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &generator{
|
|
file: f,
|
|
fset: fs,
|
|
}, nil
|
|
}
|