tailscale/cmd/viewer/viewer.go
Josh Bleecher Snyder c7b7546587 WIP snapshot
Next up: view support for maps, etc.
2021-09-17 16:47:00 -07:00

217 lines
6.5 KiB
Go

// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Viewer is a tool to automate the creation of a view type.
//
// The generated View method provides a read-only view of the struct.
//
// This tool makes lots of implicit assumptions about the types you feed it.
// In particular, it can only write relatively "shallow" View methods.
// That is, if a type contains another named struct type, viewer assumes that
// named type will also have a View method.
package main
import (
	"bytes"
	"flag"
	"fmt"
	"go/types"
	"log"
	"os"
	"sort"
	"strings"

	"golang.org/x/tools/go/packages"
	"tailscale.com/util/codegen"
)
// Command-line flags, populated by flag.Parse in main.
var (
	// flagTypes is the -type flag: a comma-separated list of the names of
	// the types to generate views for.
	flagTypes = flag.String("type", "", "comma-separated list of types; required")

	// flagOutput is the -output flag: the path of the file to write the
	// generated code to.
	flagOutput = flag.String("output", "", "output file; required")

	// flagBuildTags is the -tags flag: build tags handed to the package
	// loader as a -tags build flag.
	flagBuildTags = flag.String("tags", "", "compiler build tags to apply")
)
// main parses the command-line flags and loads the Go package in the current
// directory. For each type named by -type it generates view code, and it
// writes the formatted result to the file named by -output. Missing required
// flags print usage and exit with status 2; any other failure is fatal.
func main() {
	log.SetFlags(0)
	log.SetPrefix("viewer: ")
	flag.Parse()
	// Both -type and -output are required. Check them before doing any
	// (comparatively expensive) package loading and code generation.
	if len(*flagTypes) == 0 || *flagOutput == "" {
		flag.Usage()
		os.Exit(2)
	}
	typeNames := strings.Split(*flagTypes, ",")

	cfg := &packages.Config{
		Mode:  packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName,
		Tests: false,
	}
	if *flagBuildTags != "" {
		cfg.BuildFlags = []string{"-tags=" + *flagBuildTags}
	}
	pkgs, err := packages.Load(cfg, ".")
	if err != nil {
		log.Fatal(err)
	}
	if len(pkgs) != 1 {
		log.Fatalf("wrong number of packages: %d", len(pkgs))
	}
	pkg := pkgs[0]

	buf := new(bytes.Buffer)
	imports := make(map[string]struct{})
	namedTypes := codegen.NamedTypes(pkg)
	for _, typeName := range typeNames {
		typ, ok := namedTypes[typeName]
		if !ok {
			log.Fatalf("could not find type %s", typeName)
		}
		gen(buf, imports, typ, pkg.Types)
	}

	// Emit imports in sorted order so the generated file is reproducible;
	// ranging over the map directly would give a random order on each run.
	importPaths := make([]string, 0, len(imports))
	for p := range imports {
		importPaths = append(importPaths, p)
	}
	sort.Strings(importPaths)

	contents := new(bytes.Buffer)
	fmt.Fprintf(contents, header, *flagTypes, pkg.Name)
	fmt.Fprintf(contents, "import (\n")
	for _, p := range importPaths {
		fmt.Fprintf(contents, "\t%q\n", p)
	}
	fmt.Fprintf(contents, ")\n\n")
	contents.Write(buf.Bytes())

	if err := codegen.WriteFormatted(contents.Bytes(), *flagOutput); err != nil {
		log.Fatal(err)
	}
}
// header is the fmt format string for the start of every generated file.
// Its verbs are, in order, the value of the -type flag and the name of the
// package the code is generated into.
//
// Blank lines separate the license, the generated-code notice, and the
// package clause. Without them, the comment block directly above the package
// clause becomes the package's doc comment, and every generated file would
// inject the license and "DO NOT EDIT" text into the package documentation.
const header = `// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by the following command; DO NOT EDIT.
// tailscale.com/cmd/viewer -type %s

package %s
`
// gen writes to buf the generated code for the named struct type typ:
//   - a View method on *typ;
//   - a wrapper type named typ's name plus "View";
//   - a Valid method;
//   - one accessor method per struct field;
//   - the output of codegen.AssertStructUnchanged.
//
// Import paths of packages other than thisPkg that the generated code
// references are added to imports. If typ's underlying type is not a struct,
// gen writes nothing.
//
// Accessors are generated per field as follows:
//   - types containing no pointers are returned by value;
//   - named types, other than named slices and maps, are assumed to have
//     their own View method, which the accessor calls;
//   - slices of pointer-free elements are returned as a copy;
//   - slices of named types, or of pointers to named types, are exposed
//     through a generated Len/At view type;
//   - pointers to named types that contain pointers delegate to View;
//   - pointers to pointer-free types return a pointer to a copy.
//
// Maps are not yet supported and are skipped with a log message. Every other
// field type, including pointers to types that themselves contain pointers,
// is a fatal error.
func gen(buf *bytes.Buffer, imports map[string]struct{}, typ *types.Named, thisPkg *types.Package) {
	// pkgQual qualifies type names from other packages and records those
	// packages as imports of the generated file.
	pkgQual := func(pkg *types.Package) string {
		if thisPkg == pkg {
			return ""
		}
		imports[pkg.Path()] = struct{}{}
		return pkg.Name()
	}
	importedName := func(t types.Type) string {
		return types.TypeString(t, pkgQual)
	}

	t, ok := typ.Underlying().(*types.Struct)
	if !ok {
		return
	}
	name := typ.Obj().Name()
	viewName := name + "View"
	fmt.Fprintf(buf, "// View makes a readonly view of %s.\n", name)
	fmt.Fprintf(buf, "func (src *%s) View() %s {\n", name, viewName)
	fmt.Fprintf(buf, "\treturn %s{src}\n", viewName)
	fmt.Fprintf(buf, "}\n")
	fmt.Fprintf(buf, "// %s is a readonly view of %s.\n", viewName, name)
	fmt.Fprintf(buf, "type %s struct{ ж *%s }\n", viewName, name)
	fmt.Fprintf(buf, "func (v %s) Valid() bool { return v.ж != nil }\n", viewName)

	for i := 0; i < t.NumFields(); i++ {
		fname := t.Field(i).Name()
		ft := t.Field(i).Type()
		if !codegen.ContainsPointers(ft) {
			fmt.Fprintf(buf, "func (v %s) %s() %s { return v.ж.%s }\n", viewName, fname, importedName(ft), fname)
			continue
		}
		if _, isNamed := ft.(*types.Named); isNamed && !hasBasicUnderlying(ft) {
			genViewCall(buf, viewName, fname, importedName(ft))
			continue
		}
		switch ft := ft.Underlying().(type) {
		case *types.Slice:
			if !codegen.ContainsPointers(ft.Elem()) {
				// Return a copy. The field's own slice shares its backing
				// array with the struct, so returning it would let callers
				// overwrite elements through a "readonly" view. Re-slicing
				// to [:0:0] keeps a nil slice nil.
				fmt.Fprintf(buf, "func (v %s) %s() %s { return append(v.ж.%s[:0:0], v.ж.%s...) }\n", viewName, fname, importedName(ft), fname, fname)
				continue
			}
			elem := ft.Elem()
			if ptrTyp, isPtr := elem.(*types.Pointer); isPtr {
				elem = ptrTyp.Elem()
			}
			// The generated At method returns <elem>View via elem.View(), so
			// the element must be a named type. Otherwise the output would
			// reference a type like "intView" that cannot exist.
			if _, isNamed := elem.(*types.Named); !isNamed {
				log.Fatalf("unhandled: %s.%s: slice element type %v", name, fname, ft.Elem())
			}
			n := importedName(elem)
			// Generate slice view.
			styp := fmt.Sprintf("_%s_%s", viewName, fname)
			fmt.Fprintf(buf, "type %s []%s\n", styp, importedName(ft.Elem()))
			fmt.Fprintf(buf, "func (s %s) Len() int { return len(s) }\n", styp)
			fmt.Fprintf(buf, "func (s %s) At(i int) %sView { return s[i].View() }\n", styp, n)
			fmt.Fprintf(buf, "func (v %s) %s() interface { Len() int; At(int) %sView } {\n", viewName, fname, n)
			fmt.Fprintf(buf, "\treturn %s(v.ж.%s)\n", styp, fname)
			fmt.Fprintf(buf, "}\n")
		case *types.Pointer:
			if named, _ := ft.Elem().(*types.Named); named != nil && codegen.ContainsPointers(ft.Elem()) {
				genViewCall(buf, viewName, fname, importedName(named))
				continue
			}
			if codegen.ContainsPointers(ft.Elem()) {
				log.Fatalf("unhandled: pointers in pointers (%v)", ft)
			}
			// Pointer to a pointer-free type: hand out a pointer to a copy so
			// writes through it cannot reach the underlying struct.
			n := importedName(ft.Elem())
			fmt.Fprintf(buf, "func (v %s) %s() *%s {\n", viewName, fname, n)
			fmt.Fprintf(buf, "\tptr := v.ж.%s\n", fname)
			fmt.Fprintf(buf, "\tif ptr == nil {\n")
			fmt.Fprintf(buf, "\t\treturn nil\n")
			fmt.Fprintf(buf, "\t}\n")
			fmt.Fprintf(buf, "\tcp := *ptr\n")
			fmt.Fprintf(buf, "\treturn &cp\n")
			fmt.Fprintf(buf, "}\n")
		case *types.Map:
			// TODO: Generate map view, like the slice view.
			// We need:
			//   * Len() int
			//   * Load(k) v
			//   * LoadOK(k) (v, bool)
			//   * Range(func(k, v) bool)
			//
			// Note that we need to handle a variety of elem types:
			// basic types (float64), types with a View method,
			// slices of the foregoing.
			//
			// This may require recursion to handle completely,
			// or we can follow cloner's lead and just manually
			// inline one level deep the code generation
			// that we happen to need right now.
			// (If we figure out recursion in this context,
			// we might want to backport to cloner, too.)
			log.Printf("TODO: Handle %s (%s)", name, ft)
		default:
			// Fail loudly here. Emitting a placeholder statement would
			// produce a file that cannot be parsed or formatted.
			log.Fatalf("unhandled: %s.%s of type %v", name, fname, ft)
		}
	}
	buf.Write(codegen.AssertStructUnchanged(t, thisPkg, name, "View", imports))
}
// genViewCall writes to buf a one-line accessor method named fieldName on
// the view type viewName. The accessor returns the field's own view by
// calling View on it. importedName is the field's type as spelled in the
// generated file. The accessor's return type is that name with "View"
// appended.
func genViewCall(buf *bytes.Buffer, viewName, fieldName, importedName string) {
	const accessorFormat = "func (v %s) %s() %sView { return v.ж.%s.View() }\n"
	args := []interface{}{viewName, fieldName, importedName, fieldName}
	fmt.Fprintf(buf, accessorFormat, args...)
}
// hasBasicUnderlying reports whether typ's underlying type is a slice or a
// map. Despite its name, it does not test for *types.Basic. gen uses it to
// route named slice and map types to its slice/map handling rather than
// generating a call to a View method on them.
func hasBasicUnderlying(typ types.Type) bool {
	under := typ.Underlying()
	_, isSlice := under.(*types.Slice)
	_, isMap := under.(*types.Map)
	return isSlice || isMap
}