mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-17 12:07:02 +02:00
Add a go:generate helper called stubmaker, which generates appropriate stubs on ent based on OSS stubs, but only when needed (i.e., when the real ent funcs haven't been added yet).
258 lines
5.5 KiB
Go
258 lines
5.5 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"go/types"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/go-git/go-git/v5"
|
|
"github.com/go-git/go-git/v5/plumbing"
|
|
"github.com/go-git/go-git/v5/plumbing/object"
|
|
"github.com/hashicorp/go-hclog"
|
|
"golang.org/x/tools/go/packages"
|
|
)
|
|
|
|
// logger is the tool-wide logger. main initializes it as its first statement,
// and fatal uses it to report an error before exiting.
var logger hclog.Logger
|
|
|
|
func fatal(err error) {
|
|
logger.Error("fatal error", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
func main() {
|
|
logger = hclog.New(&hclog.LoggerOptions{
|
|
Name: "stubmaker",
|
|
Level: hclog.Trace,
|
|
})
|
|
|
|
repo, err := git.PlainOpenWithOptions(".", &git.PlainOpenOptions{
|
|
DetectDotGit: true,
|
|
})
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
|
|
wt, err := repo.Worktree()
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
if !isEnterprise(wt) {
|
|
return
|
|
}
|
|
|
|
head, err := repo.Head()
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
obj, err := repo.Object(plumbing.AnyObject, head.Hash())
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
|
|
st, err := wt.Status()
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
|
|
inputFile := os.Getenv("GOFILE")
|
|
if !strings.HasSuffix(inputFile, "_oss.go") {
|
|
fatal(fmt.Errorf("stubmaker should only be invoked from files ending in _oss.go"))
|
|
}
|
|
|
|
baseFilename := strings.TrimSuffix(inputFile, "_oss.go")
|
|
target := baseFilename + "_ent.go"
|
|
|
|
tracked, err := inGit(wt, st, obj, target)
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
if tracked {
|
|
fatal(fmt.Errorf("output file %s exists in git, not overwriting", target))
|
|
}
|
|
|
|
if err := writeStubIfNeeded(inputFile, target); err != nil {
|
|
fatal(err)
|
|
}
|
|
}
|
|
|
|
func inGit(wt *git.Worktree, st git.Status, obj object.Object, path string) (bool, error) {
|
|
absPath, err := filepath.Abs(path)
|
|
if err != nil {
|
|
return false, fmt.Errorf("path %s can't be made absolute: %w", path, err)
|
|
}
|
|
relPath, err := filepath.Rel(wt.Filesystem.Root(), absPath)
|
|
if err != nil {
|
|
return false, fmt.Errorf("path %s can't be made relative: %w", absPath, err)
|
|
}
|
|
|
|
fst := st.File(relPath)
|
|
if fst.Worktree != git.Untracked || fst.Staging != git.Untracked {
|
|
return true, nil
|
|
}
|
|
|
|
curwd, err := os.Getwd()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
blob, err := resolve(obj, relPath)
|
|
if err != nil && !strings.Contains(err.Error(), "file not found") {
|
|
return false, fmt.Errorf("error resolving path %s from %s: %w", relPath, curwd, err)
|
|
}
|
|
|
|
return blob != nil, nil
|
|
}
|
|
|
|
func isEnterprise(wt *git.Worktree) bool {
|
|
st, err := wt.Filesystem.Stat("enthelpers")
|
|
onOss := errors.Is(err, os.ErrNotExist)
|
|
onEnt := st != nil
|
|
|
|
switch {
|
|
case onOss && !onEnt:
|
|
case !onOss && onEnt:
|
|
default:
|
|
fatal(err)
|
|
}
|
|
return onEnt
|
|
}
|
|
|
|
// resolve blob at given path from obj. obj can be a commit, tag, tree, or blob.
|
|
func resolve(obj object.Object, path string) (*object.Blob, error) {
|
|
switch o := obj.(type) {
|
|
case *object.Commit:
|
|
t, err := o.Tree()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return resolve(t, path)
|
|
case *object.Tag:
|
|
target, err := o.Object()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return resolve(target, path)
|
|
case *object.Tree:
|
|
file, err := o.File(path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &file.Blob, nil
|
|
case *object.Blob:
|
|
return o, nil
|
|
default:
|
|
return nil, object.ErrUnsupportedObject
|
|
}
|
|
}
|
|
|
|
// readLines reads r to EOF and returns its contents split into lines. Line
// terminators ("\n" or "\r\n") are stripped, and a trailing newline does not
// produce an extra empty line. Empty input yields a nil slice. An error is
// returned only if reading from r fails.
func readLines(r io.Reader) ([]string, error) {
	// bufio.Scanner rejects lines longer than bufio.MaxScanTokenSize (64 KiB)
	// by default, which would make the tool fail on source files containing
	// long lines. Raise the cap well beyond any realistic source line.
	const maxLineSize = 1 << 30

	// The default split function is bufio.ScanLines.
	scanner := bufio.NewScanner(r)
	scanner.Buffer(nil, maxLineSize)

	var lines []string
	for scanner.Scan() {
		lines = append(lines, scanner.Text())
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return lines, nil
}
|
|
|
|
func writeStubIfNeeded(inputFile, outputFile string) (err error) {
|
|
warning := "// Code generated by tools/stubmaker; DO NOT EDIT."
|
|
|
|
var output *os.File
|
|
b, err := os.ReadFile(inputFile)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
inputLines, err := readLines(bytes.NewBuffer(b))
|
|
var funcs []string
|
|
var outputLines []string
|
|
for _, line := range inputLines {
|
|
switch line {
|
|
case "//go:build !enterprise":
|
|
outputLines = append(outputLines, warning, "")
|
|
line = "//go:build enterprise"
|
|
case "//go:generate go run github.com/hashicorp/vault/tools/stubmaker":
|
|
continue
|
|
}
|
|
outputLines = append(outputLines, line)
|
|
|
|
trimmed := strings.TrimSpace(line)
|
|
if strings.HasPrefix(trimmed, "func ") {
|
|
i := strings.Index(trimmed, "(")
|
|
if i != -1 {
|
|
funcs = append(funcs, trimmed[5:i])
|
|
}
|
|
}
|
|
}
|
|
|
|
pkg, err := parsePackage(".", []string{"enterprise"})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var found []string
|
|
for name, val := range pkg.TypesInfo.Defs {
|
|
if val == nil {
|
|
continue
|
|
}
|
|
_, ok := val.Type().(*types.Signature)
|
|
if !ok {
|
|
continue
|
|
}
|
|
for _, f := range funcs {
|
|
if name.Name == f {
|
|
found = append(found, f)
|
|
}
|
|
}
|
|
}
|
|
switch {
|
|
case len(found) == len(funcs):
|
|
return nil
|
|
case len(found) != 0:
|
|
return fmt.Errorf("funcs partially defined: need=%v, found=%v", funcs, found)
|
|
}
|
|
|
|
output, err = os.Create(outputFile + ".tmp")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// If we don't end up writing to the file, delete it.
|
|
defer func() {
|
|
if err != nil {
|
|
os.Remove(outputFile + ".tmp")
|
|
} else {
|
|
os.Rename(outputFile+".tmp", outputFile)
|
|
}
|
|
}()
|
|
|
|
_, err = io.WriteString(output, strings.Join(outputLines, "\n")+"\n")
|
|
return err
|
|
}
|
|
|
|
func parsePackage(name string, tags []string) (*packages.Package, error) {
|
|
cfg := &packages.Config{
|
|
Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax,
|
|
Tests: false,
|
|
BuildFlags: []string{fmt.Sprintf("-tags=%s", strings.Join(tags, " "))},
|
|
}
|
|
pkgs, err := packages.Load(cfg, name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error parsing package %s: %v", name, err)
|
|
}
|
|
if len(pkgs) != 1 {
|
|
return nil, fmt.Errorf("error: %d packages found", len(pkgs))
|
|
}
|
|
return pkgs[0], nil
|
|
}
|