mirror of
https://github.com/tailscale/tailscale.git
synced 2026-04-14 18:11:20 +02:00
The cloner and viewer code generators didn't handle named types
whose underlying type is a map or slice and that define their own
Clone or View methods. For example, a type like:
type Map map[string]any
func (m Map) Clone() Map { ... }
func (m Map) View() MapView { ... }
When used as a struct field, the cloner would descend into the
underlying map[string]any and fail because it can't clone the any
(interface{}) value type. Similarly, the viewer would try to create
a MapFnOf view and fail.
Fix the cloner to check for a Clone method on the named type
before falling through to the underlying type handling.
Fix the viewer to check for a View method on named map/slice types,
so the type author can provide a purpose-built safe view that
doesn't leak raw any values. Named map/slice types without a View
method fall through to normal handling, which correctly rejects
types like map[string]any as unsupported.
Updates tailscale/corp#39502 (needed by tailscale/corp#39594)
Change-Id: Iaef0192a221e02b4b8e409c99ef8398090327744
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
178 lines
4.6 KiB
Go
178 lines
4.6 KiB
Go
// Copyright (c) Tailscale Inc & contributors
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/parser"
|
|
"go/token"
|
|
"go/types"
|
|
"strings"
|
|
"testing"
|
|
|
|
"tailscale.com/util/codegen"
|
|
)
|
|
|
|
// TestNamedMapWithView verifies view generation for a struct field whose
// type is a named map that declares its own View() method: the accessor
// emitted for that field must call .View() on it and return the
// hand-written view type.
func TestNamedMapWithView(t *testing.T) {
	const src = `
package test

// AttrMap is a named map whose values are opaque (any).
// It provides its own Clone and View methods.
type AttrMap map[string]any

func (m AttrMap) Clone() AttrMap {
	m2 := make(AttrMap, len(m))
	for k, v := range m { m2[k] = v }
	return m2
}

// AttrMapView is a hand-written read-only view of AttrMap.
type AttrMapView struct{ m AttrMap }

func (m AttrMap) View() AttrMapView { return AttrMapView{m} }

// Container holds an AttrMap field.
type Container struct {
	Attrs AttrMap
}
`
	// Expected accessor: it invokes .View() on the field and yields the
	// user-defined AttrMapView rather than views.Map or the raw AttrMap.
	const wantAccessor = "func (v ContainerView) Attrs() AttrMapView { return v.ж.Attrs.View() }"

	generated := genViewOutput(t, src, "Container")
	if strings.Contains(generated, wantAccessor) {
		return
	}
	t.Errorf("generated output missing expected accessor\nwant: %s\ngot:\n%s", wantAccessor, generated)
}
|
|
|
|
// TestNamedMapWithoutView verifies view generation for a struct field
// whose type is a named map[string]any that has a Clone() method but no
// View() method: the generated code must neither call .View() on the
// field nor mention an AttrMapView type.
func TestNamedMapWithoutView(t *testing.T) {
	const src = `
package test

type AttrMap map[string]any

func (m AttrMap) Clone() AttrMap {
	m2 := make(AttrMap, len(m))
	for k, v := range m { m2[k] = v }
	return m2
}

type Container struct {
	Attrs AttrMap
}
`
	generated := genViewOutput(t, src, "Container")

	// Substrings that must be absent because AttrMap declares no View()
	// method (and so no AttrMapView exists), each paired with the
	// complaint reported when it is found.
	forbidden := []struct {
		substr    string
		complaint string
	}{
		{"Attrs.View()", "generated code calls .Attrs.View() but AttrMap has no View method"},
		{"AttrMapView", "generated code references AttrMapView but it doesn't exist"},
	}
	for _, f := range forbidden {
		if strings.Contains(generated, f.substr) {
			t.Errorf("%s:\n%s", f.complaint, generated)
		}
	}
}
|
|
|
|
// genViewOutput parses src, runs genView on the named type, and returns
|
|
// the generated Go source.
|
|
func genViewOutput(t *testing.T, src string, typeName string) string {
|
|
t.Helper()
|
|
fset := token.NewFileSet()
|
|
f, err := parser.ParseFile(fset, "test.go", src, 0)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
conf := types.Config{}
|
|
pkg, err := conf.Check("test", fset, []*ast.File{f}, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
obj := pkg.Scope().Lookup(typeName)
|
|
if obj == nil {
|
|
t.Fatalf("type %q not found", typeName)
|
|
}
|
|
named, ok := obj.(*types.TypeName).Type().(*types.Named)
|
|
if !ok {
|
|
t.Fatalf("%q is not a named type", typeName)
|
|
}
|
|
var buf bytes.Buffer
|
|
tracker := codegen.NewImportTracker(pkg)
|
|
genView(&buf, tracker, named, nil)
|
|
return buf.String()
|
|
}
|
|
|
|
func TestViewerImports(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
content string
|
|
typeNames []string
|
|
wantImports [][2]string
|
|
}{
|
|
{
|
|
name: "Map",
|
|
content: `type Test struct { Map map[string]int }`,
|
|
typeNames: []string{"Test"},
|
|
wantImports: [][2]string{{"", "tailscale.com/types/views"}},
|
|
},
|
|
{
|
|
name: "Slice",
|
|
content: `type Test struct { Slice []int }`,
|
|
typeNames: []string{"Test"},
|
|
wantImports: [][2]string{{"", "tailscale.com/types/views"}},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
fset := token.NewFileSet()
|
|
f, err := parser.ParseFile(fset, "test.go", "package test\n\n"+tt.content, 0)
|
|
if err != nil {
|
|
fmt.Println("Error parsing:", err)
|
|
return
|
|
}
|
|
|
|
info := &types.Info{
|
|
Types: make(map[ast.Expr]types.TypeAndValue),
|
|
}
|
|
|
|
conf := types.Config{}
|
|
pkg, err := conf.Check("", fset, []*ast.File{f}, info)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var fieldComments map[fieldNameKey]string // don't need it for this test.
|
|
|
|
var output bytes.Buffer
|
|
tracker := codegen.NewImportTracker(pkg)
|
|
for i := range tt.typeNames {
|
|
typeName, ok := pkg.Scope().Lookup(tt.typeNames[i]).(*types.TypeName)
|
|
if !ok {
|
|
t.Fatalf("type %q does not exist", tt.typeNames[i])
|
|
}
|
|
namedType, ok := typeName.Type().(*types.Named)
|
|
if !ok {
|
|
t.Fatalf("%q is not a named type", tt.typeNames[i])
|
|
}
|
|
genView(&output, tracker, namedType, fieldComments)
|
|
}
|
|
|
|
for _, pkg := range tt.wantImports {
|
|
if !tracker.Has(pkg[0], pkg[1]) {
|
|
t.Errorf("missing import %q", pkg)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|