mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-26 05:41:04 +01:00 
			
		
		
		
	We have several checked type assertions to *types.Named in both cmd/cloner and cmd/viewer. As Go 1.23 updates the go/types package to produce Alias type nodes for type aliases, these type assertions no longer work as expected unless the new behavior is disabled with gotypesalias=0. In this PR, we add codegen.NamedTypeOf(t types.Type), which functions like t.(*types.Named) but also unrolls type aliases. We then use it in place of type assertions in the cmd/cloner and cmd/viewer packages where appropriate. We also update type switches to include *types.Alias alongside *types.Named in relevant cases, remove *types.Struct cases when switching on types.Type.Underlying, and update the tests with more cases where type aliases can be used.

Updates #13224
Updates #12912

Signed-off-by: Nick Khyl <nickk@tailscale.com>
		
			
				
	
	
		
			394 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			394 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) Tailscale Inc & AUTHORS
 | |
| // SPDX-License-Identifier: BSD-3-Clause
 | |
| 
 | |
| // Package codegen contains shared utilities for generating code.
 | |
| package codegen
 | |
| 
 | |
import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/token"
	"go/types"
	"io"
	"maps"
	"os"
	"reflect"
	"slices"
	"strings"

	"golang.org/x/tools/go/packages"
	"golang.org/x/tools/imports"
	"tailscale.com/util/mak"
)
 | |
| 
 | |
// flagCopyright controls whether writeHeader prefixes generated files with
// copyrightHeader. It is registered as the -copyright flag and defaults to true.
var flagCopyright = flag.Bool("copyright", true, "add Tailscale copyright to generated file headers")
 | |
| 
 | |
| // LoadTypes returns all named types in pkgName, keyed by their type name.
 | |
| func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]types.Type, error) {
 | |
| 	cfg := &packages.Config{
 | |
| 		Mode:  packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName,
 | |
| 		Tests: buildTags == "test",
 | |
| 	}
 | |
| 	if buildTags != "" && !cfg.Tests {
 | |
| 		cfg.BuildFlags = []string{"-tags=" + buildTags}
 | |
| 	}
 | |
| 
 | |
| 	pkgs, err := packages.Load(cfg, pkgName)
 | |
| 	if err != nil {
 | |
| 		return nil, nil, err
 | |
| 	}
 | |
| 	if cfg.Tests {
 | |
| 		pkgs = testPackages(pkgs)
 | |
| 	}
 | |
| 	if len(pkgs) != 1 {
 | |
| 		return nil, nil, fmt.Errorf("wrong number of packages: %d", len(pkgs))
 | |
| 	}
 | |
| 	pkg := pkgs[0]
 | |
| 	return pkg, namedTypes(pkg), nil
 | |
| }
 | |
| 
 | |
| func testPackages(pkgs []*packages.Package) []*packages.Package {
 | |
| 	var testPackages []*packages.Package
 | |
| 	for _, pkg := range pkgs {
 | |
| 		testPackageID := fmt.Sprintf("%[1]s [%[1]s.test]", pkg.PkgPath)
 | |
| 		if pkg.ID == testPackageID {
 | |
| 			testPackages = append(testPackages, pkg)
 | |
| 		}
 | |
| 	}
 | |
| 	return testPackages
 | |
| }
 | |
| 
 | |
// HasNoClone reports whether the provided tag has `codegen:noclone`, i.e.
// whether "noclone" appears as one of the comma-separated options of the
// struct tag's "codegen" key.
func HasNoClone(structTag string) bool {
	opts := reflect.StructTag(structTag).Get("codegen")
	for opts != "" {
		var opt string
		opt, opts, _ = strings.Cut(opts, ",")
		if opt == "noclone" {
			return true
		}
	}
	return false
}
 | |
| 
 | |
const (
	// copyrightHeader is the license banner optionally written at the very
	// top of every generated file (see flagCopyright).
	copyrightHeader = `// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

`

	// genAndPackageHeader is a format string taking the generating tool's
	// identifier (%v) and the package name (%s); it yields the "Code
	// generated ... DO NOT EDIT." marker followed by the package clause.
	genAndPackageHeader = `// Code generated by %v; DO NOT EDIT.

package %s
`
)
 | |
| 
 | |
| func NewImportTracker(thisPkg *types.Package) *ImportTracker {
 | |
| 	return &ImportTracker{
 | |
| 		thisPkg: thisPkg,
 | |
| 	}
 | |
| }
 | |
| 
 | |
// ImportTracker provides a mechanism to track and build import paths.
//
// It acts as the go/types qualifier used by QualifiedName and PackagePrefix,
// recording each foreign package those methods reference; Write then emits
// the recorded paths as a single import block.
type ImportTracker struct {
	thisPkg  *types.Package  // package the generated code belongs to; its identifiers are left unqualified
	packages map[string]bool // import paths recorded via Import
}
 | |
| 
 | |
| func (it *ImportTracker) Import(pkg string) {
 | |
| 	if pkg != "" && !it.packages[pkg] {
 | |
| 		mak.Set(&it.packages, pkg, true)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (it *ImportTracker) qualifier(pkg *types.Package) string {
 | |
| 	if it.thisPkg == pkg {
 | |
| 		return ""
 | |
| 	}
 | |
| 	it.Import(pkg.Path())
 | |
| 	// TODO(maisem): handle conflicts?
 | |
| 	return pkg.Name()
 | |
| }
 | |
| 
 | |
| // QualifiedName returns the string representation of t in the package.
 | |
| func (it *ImportTracker) QualifiedName(t types.Type) string {
 | |
| 	return types.TypeString(t, it.qualifier)
 | |
| }
 | |
| 
 | |
| // PackagePrefix returns the prefix to be used when referencing named objects from pkg.
 | |
| func (it *ImportTracker) PackagePrefix(pkg *types.Package) string {
 | |
| 	if s := it.qualifier(pkg); s != "" {
 | |
| 		return s + "."
 | |
| 	}
 | |
| 	return ""
 | |
| }
 | |
| 
 | |
| // Write prints all the tracked imports in a single import block to w.
 | |
| func (it *ImportTracker) Write(w io.Writer) {
 | |
| 	fmt.Fprintf(w, "import (\n")
 | |
| 	for s := range it.packages {
 | |
| 		fmt.Fprintf(w, "\t%q\n", s)
 | |
| 	}
 | |
| 	fmt.Fprintf(w, ")\n\n")
 | |
| }
 | |
| 
 | |
| func writeHeader(w io.Writer, tool, pkg string) {
 | |
| 	if *flagCopyright {
 | |
| 		fmt.Fprint(w, copyrightHeader)
 | |
| 	}
 | |
| 	fmt.Fprintf(w, genAndPackageHeader, tool, pkg)
 | |
| }
 | |
| 
 | |
| // WritePackageFile adds a file with the provided imports and contents to package.
 | |
| // The tool param is used to identify the tool that generated package file.
 | |
| func WritePackageFile(tool string, pkg *packages.Package, path string, it *ImportTracker, contents *bytes.Buffer) error {
 | |
| 	buf := new(bytes.Buffer)
 | |
| 	writeHeader(buf, tool, pkg.Name)
 | |
| 	it.Write(buf)
 | |
| 	if _, err := buf.Write(contents.Bytes()); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	return writeFormatted(buf.Bytes(), path)
 | |
| }
 | |
| 
 | |
| // writeFormatted writes code to path.
 | |
| // It runs gofmt on it before writing;
 | |
| // if gofmt fails, it writes code unchanged.
 | |
| // Errors can include I/O errors and gofmt errors.
 | |
| //
 | |
| // The advantage of always writing code to path,
 | |
| // even if gofmt fails, is that it makes debugging easier.
 | |
| // The code can be long, but you need it in order to debug.
 | |
| // It is nicer to work with it in a file than a terminal.
 | |
| // It is also easier to interpret gofmt errors
 | |
| // with an editor providing file and line numbers.
 | |
| func writeFormatted(code []byte, path string) error {
 | |
| 	out, fmterr := imports.Process(path, code, &imports.Options{
 | |
| 		Comments:   true,
 | |
| 		TabIndent:  true,
 | |
| 		TabWidth:   8,
 | |
| 		FormatOnly: true, // fancy gofmt only
 | |
| 	})
 | |
| 	if fmterr != nil {
 | |
| 		out = code
 | |
| 	}
 | |
| 	ioerr := os.WriteFile(path, out, 0644)
 | |
| 	// Prefer I/O errors. They're usually easier to fix,
 | |
| 	// and until they're fixed you can't do much else.
 | |
| 	if ioerr != nil {
 | |
| 		return ioerr
 | |
| 	}
 | |
| 	if fmterr != nil {
 | |
| 		return fmt.Errorf("%s:%v", path, fmterr)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // namedTypes returns all named types in pkg, keyed by their type name.
 | |
| func namedTypes(pkg *packages.Package) map[string]types.Type {
 | |
| 	nt := make(map[string]types.Type)
 | |
| 	for _, file := range pkg.Syntax {
 | |
| 		for _, d := range file.Decls {
 | |
| 			decl, ok := d.(*ast.GenDecl)
 | |
| 			if !ok || decl.Tok != token.TYPE {
 | |
| 				continue
 | |
| 			}
 | |
| 			for _, s := range decl.Specs {
 | |
| 				spec, ok := s.(*ast.TypeSpec)
 | |
| 				if !ok {
 | |
| 					continue
 | |
| 				}
 | |
| 				typeNameObj, ok := pkg.TypesInfo.Defs[spec.Name]
 | |
| 				if !ok {
 | |
| 					continue
 | |
| 				}
 | |
| 				switch typ := typeNameObj.Type(); typ.(type) {
 | |
| 				case *types.Alias, *types.Named:
 | |
| 					nt[spec.Name.Name] = typ
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	return nt
 | |
| }
 | |
| 
 | |
| // AssertStructUnchanged generates code that asserts at compile time that type t is unchanged.
 | |
| // thisPkg is the package containing t.
 | |
| // tname is the named type corresponding to t.
 | |
| // ctx is a single-word context for this assertion, such as "Clone".
 | |
| // If non-nil, AssertStructUnchanged will add elements to imports
 | |
| // for each package path that the caller must import for the returned code to compile.
 | |
| func AssertStructUnchanged(t *types.Struct, tname string, params *types.TypeParamList, ctx string, it *ImportTracker) []byte {
 | |
| 	buf := new(bytes.Buffer)
 | |
| 	w := func(format string, args ...any) {
 | |
| 		fmt.Fprintf(buf, format+"\n", args...)
 | |
| 	}
 | |
| 	w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.")
 | |
| 
 | |
| 	hasTypeParams := params != nil && params.Len() > 0
 | |
| 	if hasTypeParams {
 | |
| 		constraints, identifiers := FormatTypeParams(params, it)
 | |
| 		w("func _%s%sNeedsRegeneration%s (%s%s) {", tname, ctx, constraints, tname, identifiers)
 | |
| 		w("_%s%sNeedsRegeneration(struct {", tname, ctx)
 | |
| 	} else {
 | |
| 		w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)
 | |
| 	}
 | |
| 
 | |
| 	for i := range t.NumFields() {
 | |
| 		st := t.Field(i)
 | |
| 		fname := st.Name()
 | |
| 		ft := t.Field(i).Type()
 | |
| 		if IsInvalid(ft) {
 | |
| 			continue
 | |
| 		}
 | |
| 		qname := it.QualifiedName(ft)
 | |
| 		var tag string
 | |
| 		if hasTypeParams {
 | |
| 			tag = t.Tag(i)
 | |
| 			if tag != "" {
 | |
| 				tag = "`" + tag + "`"
 | |
| 			}
 | |
| 		}
 | |
| 		if st.Anonymous() {
 | |
| 			w("\t%s %s", qname, tag)
 | |
| 		} else {
 | |
| 			w("\t%s %s %s", fname, qname, tag)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if hasTypeParams {
 | |
| 		w("}{})\n}")
 | |
| 	} else {
 | |
| 		w("}{})")
 | |
| 	}
 | |
| 	return buf.Bytes()
 | |
| }
 | |
| 
 | |
| // IsInvalid reports whether the provided type is invalid. It is used to allow
 | |
| // codegeneration to run even when the target files have build errors or are
 | |
| // missing views.
 | |
| func IsInvalid(t types.Type) bool {
 | |
| 	return t.String() == "invalid type"
 | |
| }
 | |
| 
 | |
| // ContainsPointers reports whether typ contains any pointers,
 | |
| // either explicitly or implicitly.
 | |
| // It has special handling for some types that contain pointers
 | |
| // that we know are free from memory aliasing/mutation concerns.
 | |
| func ContainsPointers(typ types.Type) bool {
 | |
| 	switch typ.String() {
 | |
| 	case "time.Time":
 | |
| 		// time.Time contains a pointer that does not need copying
 | |
| 		return false
 | |
| 	case "inet.af/netip.Addr", "net/netip.Addr", "net/netip.Prefix", "net/netip.AddrPort":
 | |
| 		return false
 | |
| 	}
 | |
| 	switch ft := typ.Underlying().(type) {
 | |
| 	case *types.Array:
 | |
| 		return ContainsPointers(ft.Elem())
 | |
| 	case *types.Basic:
 | |
| 		if ft.Kind() == types.UnsafePointer {
 | |
| 			return true
 | |
| 		}
 | |
| 	case *types.Chan:
 | |
| 		return true
 | |
| 	case *types.Interface:
 | |
| 		if ft.Empty() || ft.IsMethodSet() {
 | |
| 			return true
 | |
| 		}
 | |
| 		for i := 0; i < ft.NumEmbeddeds(); i++ {
 | |
| 			if ContainsPointers(ft.EmbeddedType(i)) {
 | |
| 				return true
 | |
| 			}
 | |
| 		}
 | |
| 	case *types.Map:
 | |
| 		return true
 | |
| 	case *types.Pointer:
 | |
| 		return true
 | |
| 	case *types.Slice:
 | |
| 		return true
 | |
| 	case *types.Struct:
 | |
| 		for i := range ft.NumFields() {
 | |
| 			if ContainsPointers(ft.Field(i).Type()) {
 | |
| 				return true
 | |
| 			}
 | |
| 		}
 | |
| 	case *types.Union:
 | |
| 		for i := range ft.Len() {
 | |
| 			if ContainsPointers(ft.Term(i).Type()) {
 | |
| 				return true
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	return false
 | |
| }
 | |
| 
 | |
// IsViewType reports whether the provided typ is a View, recognized as a type
// whose underlying struct has exactly one field, and that field is named "ж".
func IsViewType(typ types.Type) bool {
	st, isStruct := typ.Underlying().(*types.Struct)
	return isStruct && st.NumFields() == 1 && st.Field(0).Name() == "ж"
}
 | |
| 
 | |
| // FormatTypeParams formats the specified params and returns two strings:
 | |
| //   - constraints are comma-separated type parameters and their constraints in square brackets (e.g. [T any, V constraints.Integer])
 | |
| //   - names are comma-separated type parameter names in square brackets (e.g. [T, V])
 | |
| //
 | |
| // If params is nil or empty, both return values are empty strings.
 | |
| func FormatTypeParams(params *types.TypeParamList, it *ImportTracker) (constraints, names string) {
 | |
| 	if params == nil || params.Len() == 0 {
 | |
| 		return "", ""
 | |
| 	}
 | |
| 	var constraintList, nameList []string
 | |
| 	for i := range params.Len() {
 | |
| 		param := params.At(i)
 | |
| 		name := param.Obj().Name()
 | |
| 		constraint := it.QualifiedName(param.Constraint())
 | |
| 		nameList = append(nameList, name)
 | |
| 		constraintList = append(constraintList, name+" "+constraint)
 | |
| 	}
 | |
| 	constraints = "[" + strings.Join(constraintList, ", ") + "]"
 | |
| 	names = "[" + strings.Join(nameList, ", ") + "]"
 | |
| 	return constraints, names
 | |
| }
 | |
| 
 | |
| // LookupMethod returns the method with the specified name in t, or nil if the method does not exist.
 | |
| func LookupMethod(t types.Type, name string) *types.Func {
 | |
| 	switch t := t.(type) {
 | |
| 	case *types.Alias:
 | |
| 		return LookupMethod(t.Rhs(), name)
 | |
| 	case *types.TypeParam:
 | |
| 		return LookupMethod(t.Constraint(), name)
 | |
| 	case *types.Pointer:
 | |
| 		return LookupMethod(t.Elem(), name)
 | |
| 	case *types.Named:
 | |
| 		switch u := t.Underlying().(type) {
 | |
| 		case *types.Interface:
 | |
| 			return LookupMethod(u, name)
 | |
| 		default:
 | |
| 			for i := 0; i < t.NumMethods(); i++ {
 | |
| 				if method := t.Method(i); method.Name() == name {
 | |
| 					return method
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 	case *types.Interface:
 | |
| 		for i := 0; i < t.NumMethods(); i++ {
 | |
| 			if method := t.Method(i); method.Name() == name {
 | |
| 				return method
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // NamedTypeOf is like t.(*types.Named), but also works with type aliases.
 | |
| func NamedTypeOf(t types.Type) (named *types.Named, ok bool) {
 | |
| 	if a, ok := t.(*types.Alias); ok {
 | |
| 		return NamedTypeOf(types.Unalias(a))
 | |
| 	}
 | |
| 	named, ok = t.(*types.Named)
 | |
| 	return
 | |
| }
 |