vault/tools/stubmaker/main.go
Nick Cabatoff afc8f7d5aa
Prevent breakage when merging OSS changes to ENT due to missing funcs (#21564)
Add a go:generate helper called stubmaker, which generates appropriate stubs on ent based on oss stubs, but only when needed (i.e. real ent funcs haven't been added yet.)
2023-07-06 12:35:06 -04:00

258 lines
5.5 KiB
Go

package main
import (
"bufio"
"bytes"
"errors"
"fmt"
"go/types"
"io"
"os"
"path/filepath"
"strings"
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing"
"github.com/go-git/go-git/v5/plumbing/object"
"github.com/hashicorp/go-hclog"
"golang.org/x/tools/go/packages"
)
// logger is the package-wide logger. main assigns it as its first statement,
// and fatal reports errors through it.
var logger hclog.Logger
// fatal reports err through the package logger at error level and then
// terminates the process with a non-zero exit status. It never returns.
func fatal(err error) {
	const exitStatus = 1

	logger.Error("fatal error", "error", err)
	os.Exit(exitStatus)
}
// main drives the stubmaker helper.
//
// It opens the git repository containing the current directory and returns
// without doing anything unless isEnterprise reports an enterprise worktree.
// Otherwise it takes the input file name from the GOFILE environment variable
// (which must end in _oss.go), derives the matching _ent.go target name,
// refuses to continue if that target is already known to git, and finally
// delegates to writeStubIfNeeded. Every failure is reported through fatal.
func main() {
	logger = hclog.New(&hclog.LoggerOptions{
		Name:  "stubmaker",
		Level: hclog.Trace,
	})

	openOpts := &git.PlainOpenOptions{DetectDotGit: true}
	repository, err := git.PlainOpenWithOptions(".", openOpts)
	if err != nil {
		fatal(err)
	}

	worktree, err := repository.Worktree()
	if err != nil {
		fatal(err)
	}

	// Nothing to generate outside of an enterprise checkout.
	if !isEnterprise(worktree) {
		return
	}

	// Gather what inGit needs: the object HEAD points at and the worktree status.
	headRef, err := repository.Head()
	if err != nil {
		fatal(err)
	}

	headObj, err := repository.Object(plumbing.AnyObject, headRef.Hash())
	if err != nil {
		fatal(err)
	}

	status, err := worktree.Status()
	if err != nil {
		fatal(err)
	}

	const ossSuffix = "_oss.go"

	source := os.Getenv("GOFILE")
	if !strings.HasSuffix(source, ossSuffix) {
		fatal(errors.New("stubmaker should only be invoked from files ending in _oss.go"))
	}
	dest := strings.TrimSuffix(source, ossSuffix) + "_ent.go"

	// Never clobber a target that git already knows about.
	known, err := inGit(worktree, status, headObj, dest)
	if err != nil {
		fatal(err)
	}
	if known {
		fatal(fmt.Errorf("output file %s exists in git, not overwriting", dest))
	}

	if genErr := writeStubIfNeeded(source, dest); genErr != nil {
		fatal(genErr)
	}
}
// inGit reports whether path is known to git.
//
// path is first made absolute and then re-expressed relative to the worktree
// root. The file counts as known when the status st lists it with any state
// other than untracked (in either the worktree or the staging area), or,
// failing that, when it can be resolved in the tree reachable from obj.
//
// An error is returned when the path cannot be converted, the current
// directory cannot be determined, or resolving the path in obj fails for any
// reason other than the file not being present.
func inGit(wt *git.Worktree, st git.Status, obj object.Object, path string) (bool, error) {
	absPath, err := filepath.Abs(path)
	if err != nil {
		return false, fmt.Errorf("path %s can't be made absolute: %w", path, err)
	}
	relPath, err := filepath.Rel(wt.Filesystem.Root(), absPath)
	if err != nil {
		return false, fmt.Errorf("path %s can't be made relative: %w", absPath, err)
	}
	// filepath.Rel yields OS-specific separators, while git paths are
	// slash-separated. This is a no-op where the separator is already '/'.
	relPath = filepath.ToSlash(relPath)

	fst := st.File(relPath)
	if fst.Worktree != git.Untracked || fst.Staging != git.Untracked {
		return true, nil
	}

	// Untracked according to the status: fall back to looking the path up in
	// obj. curwd is only used to give the error message some context.
	curwd, err := os.Getwd()
	if err != nil {
		return false, err
	}
	blob, err := resolve(obj, relPath)
	if err != nil {
		// Compare against the sentinel instead of matching the message text,
		// so that a reworded or wrapped error is still classified correctly.
		if errors.Is(err, object.ErrFileNotFound) {
			return false, nil
		}
		return false, fmt.Errorf("error resolving path %s from %s: %w", relPath, curwd, err)
	}
	return blob != nil, nil
}
// isEnterprise reports whether the worktree is an enterprise checkout, which
// is decided by stat'ing "enthelpers" in the worktree's filesystem.
//
// The two expected outcomes are "stat failed with a not-exist error and gave
// no file info" (not enterprise) and "stat gave file info and no not-exist
// error" (enterprise). Any other combination is passed to fatal.
func isEnterprise(wt *git.Worktree) bool {
	info, statErr := wt.Filesystem.Stat("enthelpers")

	missing := errors.Is(statErr, os.ErrNotExist)
	present := info != nil

	// Exactly one of the two flags must be set; both or neither means the
	// stat result is ambiguous.
	if missing == present {
		fatal(statErr)
	}
	return present
}
// resolve returns the blob stored at path beneath obj, where obj may be a
// commit, tag, tree, or blob.
//
// Commits are followed to their tree and tags to their target until a tree or
// blob is reached. A tree is searched for path; a blob is returned as-is, in
// which case path is not consulted. Any other object kind yields
// object.ErrUnsupportedObject.
func resolve(obj object.Object, path string) (*object.Blob, error) {
	current := obj
	for {
		switch v := current.(type) {
		case *object.Blob:
			return v, nil

		case *object.Tree:
			f, err := v.File(path)
			if err != nil {
				return nil, err
			}
			return &f.Blob, nil

		case *object.Commit:
			tree, err := v.Tree()
			if err != nil {
				return nil, err
			}
			current = tree

		case *object.Tag:
			next, err := v.Object()
			if err != nil {
				return nil, err
			}
			current = next

		default:
			return nil, object.ErrUnsupportedObject
		}
	}
}
// readLines consumes r to the end and returns its contents split into lines,
// without the line terminators. If scanning fails, it returns a nil slice and
// the scanner's error.
func readLines(r io.Reader) ([]string, error) {
	// A new Scanner splits on lines by default, so no split function is set.
	sc := bufio.NewScanner(r)

	var collected []string
	for sc.Scan() {
		collected = append(collected, sc.Text())
	}

	switch scanErr := sc.Err(); scanErr {
	case nil:
		return collected, nil
	default:
		return nil, scanErr
	}
}
// writeStubIfNeeded creates outputFile from inputFile when none of the
// functions declared in inputFile are defined in the package in the current
// directory, as loaded with the "enterprise" build tag.
//
// The generated content is a copy of inputFile in which the
// "//go:build !enterprise" line is replaced by a generated-code warning, a
// blank line and "//go:build enterprise", and the stubmaker go:generate
// directive is dropped.
//
// Outcomes:
//   - every function name collected from inputFile is already defined (this
//     includes the case where no names were collected): nothing is written
//     and nil is returned;
//   - only some of them are defined: an error is returned;
//   - none are defined: the content is written to outputFile+".tmp", which is
//     then renamed to outputFile. The temporary file is removed if either
//     step fails.
func writeStubIfNeeded(inputFile, outputFile string) (err error) {
	const warning = "// Code generated by tools/stubmaker; DO NOT EDIT."

	b, err := os.ReadFile(inputFile)
	if err != nil {
		return err
	}
	inputLines, err := readLines(bytes.NewReader(b))
	if err != nil {
		return fmt.Errorf("reading lines of %s: %w", inputFile, err)
	}

	var funcs []string
	var outputLines []string
	for _, line := range inputLines {
		switch line {
		case "//go:build !enterprise":
			outputLines = append(outputLines, warning, "")
			line = "//go:build enterprise"
		case "//go:generate go run github.com/hashicorp/vault/tools/stubmaker":
			continue
		}
		outputLines = append(outputLines, line)

		// Collect the text between "func " and the first "(" as the name.
		// NOTE(review): for a method declaration ("func (r T) Name(") this
		// yields an empty name that can never be matched below; input files
		// are presumably expected to hold plain functions only - confirm.
		trimmed := strings.TrimSpace(line)
		if strings.HasPrefix(trimmed, "func ") {
			if i := strings.Index(trimmed, "("); i != -1 {
				funcs = append(funcs, trimmed[5:i])
			}
		}
	}

	pkg, err := parsePackage(".", []string{"enterprise"})
	if err != nil {
		return err
	}

	needed := make(map[string]bool, len(funcs))
	for _, f := range funcs {
		needed[f] = true
	}

	// Record which needed names have a definition with a function signature.
	// A set is used so that a name defined more than once (for example a
	// function plus a method of the same name) is counted only once.
	defined := make(map[string]bool, len(needed))
	for ident, obj := range pkg.TypesInfo.Defs {
		if obj == nil {
			continue
		}
		if _, ok := obj.Type().(*types.Signature); !ok {
			continue
		}
		if needed[ident.Name] {
			defined[ident.Name] = true
		}
	}

	// Build the list in source order so the error message is deterministic.
	var found []string
	for _, f := range funcs {
		if defined[f] {
			found = append(found, f)
			defined[f] = false // list each name once
		}
	}

	switch {
	case len(found) == len(needed):
		return nil
	case len(found) != 0:
		return fmt.Errorf("funcs partially defined: need=%v, found=%v", funcs, found)
	}

	// Write to a temporary file first and rename it into place, so that a
	// failed write never leaves a truncated outputFile behind.
	tmpFile := outputFile + ".tmp"
	content := strings.Join(outputLines, "\n") + "\n"
	if err = os.WriteFile(tmpFile, []byte(content), 0o666); err != nil {
		_ = os.Remove(tmpFile) // best-effort cleanup; the write error is what matters
		return err
	}
	if err = os.Rename(tmpFile, outputFile); err != nil {
		_ = os.Remove(tmpFile) // best-effort cleanup; the rename error is what matters
		return err
	}
	return nil
}
// parsePackage loads the package matched by the pattern name with the given
// build tags enabled, requesting its name, syntax trees and type information.
// Test packages are not loaded.
//
// It returns an error when loading fails or when the pattern does not match
// exactly one package. Errors recorded on the loaded package itself are not
// inspected here.
func parsePackage(name string, tags []string) (*packages.Package, error) {
	cfg := &packages.Config{
		Mode:  packages.NeedName | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax,
		Tests: false,
		// The go command documents -tags as a comma-separated list; the
		// space-separated form is only accepted for backward compatibility.
		BuildFlags: []string{"-tags=" + strings.Join(tags, ",")},
	}
	pkgs, err := packages.Load(cfg, name)
	if err != nil {
		return nil, fmt.Errorf("error parsing package %s: %w", name, err)
	}
	if len(pkgs) != 1 {
		return nil, fmt.Errorf("error: %d packages found", len(pkgs))
	}
	return pkgs[0], nil
}