talos/hack/structprotogen/ast/ast.go
Dmitriy Matrenichev bd56621cdf
feat: add structprotogen tool
This commit adds the structprotogen tool, which is used to generate a proto file from Go structs.

Closes #6078.

Signed-off-by: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com>
2022-09-05 16:54:00 +03:00

251 lines
5.2 KiB
Go

// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
// Package ast is used to find all structs with expected tag in the given AST.
package ast
import (
	"fmt"
	"go/ast"
	"go/token"
	"strconv"
	"strings"
	"unicode"
	"unicode/utf8"

	"github.com/fatih/structtag"
	"golang.org/x/tools/go/packages"
)
// FindAllTaggedStructs walks every syntax tree of the given packages and
// collects the structs whose doc comment carries the generator marker.
//
// Each hit is stored under the package's PkgPath and the struct's type name;
// a later hit with the same key replaces an earlier one.
func FindAllTaggedStructs(pkgs []*packages.Package) TaggedStructs {
	collected := TaggedStructs{}

	for _, p := range pkgs {
		for _, syntaxFile := range p.Syntax {
			for _, found := range findAllStructs(syntaxFile) {
				entry := TaggedStruct{
					Comments: formatComments(found.Comment),
					Fields:   found.Fields,
				}

				collected.Add(p.PkgPath, found.Struct.Name.Name, entry)
			}
		}
	}

	return collected
}
// TaggedStructs is a map of tagged struct declarations to their data.
//
// Keys are StructDecl values (package plus struct name); the Get and Add
// methods build the key from their pkg and name arguments.
type TaggedStructs map[StructDecl]TaggedStruct
// TaggedStruct contains struct comments and fields.
type TaggedStruct struct {
// Comments holds the text of the struct's doc comment lines as returned by
// formatComments.
Comments []string
// Fields maps each tagged field name to the number taken from its protobuf tag.
Fields Fields
}
// Get returns the struct data stored for the given package and struct name.
//
// The boolean result reports whether an entry for that pair exists; when it
// does not, the returned TaggedStruct is the zero value.
func (t TaggedStructs) Get(pkg, name string) (TaggedStruct, bool) {
	// Keyed fields keep the lookup key in step with the one built by Add.
	val, ok := t[StructDecl{Pkg: pkg, Name: name}]

	return val, ok
}
// Add stores structData under the key formed by pkg and name, replacing any
// entry already present for that key.
func (t TaggedStructs) Add(pkg, name string, structData TaggedStruct) {
	key := StructDecl{
		Pkg:  pkg,
		Name: name,
	}

	t[key] = structData
}
// findAllStructs walks the AST rooted at f and returns every struct type that
// passes typeSpecFilter and is marked with the generator comment.
//
// For a type declaration that yields exactly one candidate struct, the marker
// is looked for first in the declaration's doc comment and then in the spec's
// own doc comment. For any other count only each spec's own doc comment is
// consulted.
func findAllStructs(f ast.Node) []structData {
	var found []structData

	// collect records one matching struct together with the comment group
	// that carried the marker.
	collect := func(spec *ast.TypeSpec, doc *ast.CommentGroup) {
		fields := getStructFieldsWithTags(spec)

		found = append(found, structData{
			Struct:  spec,
			Comment: doc,
			Fields:  fields,
		})
	}

	ast.Inspect(f, func(node ast.Node) bool {
		decl, isGenDecl := node.(*ast.GenDecl)
		if !isGenDecl || decl.Tok != token.TYPE {
			// Not a type declaration: keep descending.
			return true
		}

		specs := filter(decl.Specs, typeSpecFilter)

		if len(specs) == 1 {
			only := specs[0]

			if isTarget(decl.Doc) {
				collect(only, decl.Doc)
			} else if isTarget(only.Doc) {
				collect(only, only.Doc)
			}

			return false
		}

		// Zero candidates fall through the loop untouched; several candidates
		// are each judged by their own doc comment.
		for _, spec := range specs {
			if isTarget(spec.Doc) {
				collect(spec, spec.Doc)
			}
		}

		// The children of a type declaration are never visited.
		return false
	})

	return found
}
// structData is the per-struct record that findAllStructs produces.
type structData struct {
// Struct is the type spec declaring the struct.
Struct *ast.TypeSpec
// Comment is the doc comment group in which the marker comment was found.
Comment *ast.CommentGroup
// Fields is the result of getStructFieldsWithTags for Struct.
Fields Fields
}
// typeSpecFilter is the predicate handed to filter: it accepts n only when it
// is a type spec for a struct type that has at least one field and whose name
// passes isCapitalCase. On acceptance the spec is returned with true;
// otherwise the result is (nil, false).
func typeSpecFilter(n ast.Spec) (*ast.TypeSpec, bool) {
	spec, isTypeSpec := n.(*ast.TypeSpec)
	if !isTypeSpec {
		return nil, false
	}

	st, isStruct := spec.Type.(*ast.StructType)
	if !isStruct || st.Fields.NumFields() == 0 || !isCapitalCase(spec.Name.Name) {
		return nil, false
	}

	return spec, true
}
// isCapitalCase reports whether s starts with an upper-case letter.
//
// The first rune is decoded as UTF-8 instead of being taken from the first
// byte, so a multi-byte lower-case letter such as 'é' is not mistaken for the
// upper-case code point that happens to share its lead byte. An empty string
// or invalid UTF-8 decodes to utf8.RuneError, which is not upper case, so the
// result is false.
func isCapitalCase(s string) bool {
	first, _ := utf8.DecodeRuneInString(s)

	return unicode.IsUpper(first)
}
// isTarget reports whether any comment in doc is the generator marker
// recognised by isTargetComment. A nil group never matches.
func isTarget(doc *ast.CommentGroup) bool {
	if doc != nil {
		for i := range doc.List {
			if isTargetComment(doc.List[i].Text) {
				return true
			}
		}
	}

	return false
}
// isTargetComment reports whether str is exactly the generator marker
// comment; surrounding whitespace or any other deviation does not match.
func isTargetComment(str string) bool {
	const marker = "//gotagsrewrite:gen"

	return str == marker
}
// filter applies f to every element of slc, in order, and gathers the values
// for which f returns true. The result is nil when nothing is accepted.
func filter[T, V any](slc []T, f func(n T) (V, bool)) []V {
	var accepted []V

	for i := range slc {
		if mapped, keep := f(slc[i]); keep {
			accepted = append(accepted, mapped)
		}
	}

	return accepted
}
// StructDecl is a struct declaration.
//
// It is the key type of TaggedStructs.
type StructDecl struct {
// Pkg identifies the package; FindAllTaggedStructs stores the package's
// PkgPath here.
Pkg string
// Name is the name of the struct type.
Name string
}
// formatComments returns the text of the comments in the group, leaving out
// the generator marker comment and the comment immediately preceding it.
// A nil group yields nil; otherwise the result is a non-nil slice.
func formatComments(comment *ast.CommentGroup) []string {
	if comment == nil {
		return nil
	}

	lines := comment.List
	kept := make([]string, 0, len(lines))

	for idx, line := range lines {
		precedesMarker := idx+1 < len(lines) && isTargetComment(lines[idx+1].Text)
		if precedesMarker || isTargetComment(line.Text) {
			continue
		}

		kept = append(kept, line.Text)
	}

	return kept
}
// Fields maps a struct field name to its protobuf number.
type Fields map[string]int
// getStructFieldsWithTags returns the protobuf numbers declared in the struct
// tags of structDecl, keyed by field name.
//
// Embedded fields and fields without a tag are skipped. Every tagged field
// must carry a `protobuf:"<number>"` key: the function panics when the tag
// cannot be parsed, has no protobuf key, or the key's name is not an integer.
// It also panics if structDecl.Type is not an *ast.StructType.
func getStructFieldsWithTags(structDecl *ast.TypeSpec) Fields {
	result := Fields{}
	structType := structDecl.Type.(*ast.StructType) //nolint:errcheck

	for _, field := range structType.Fields.List {
		// Embedded fields have no names; untagged fields carry no number.
		if len(field.Names) == 0 || field.Tag == nil {
			continue
		}

		// All names of one declaration share a single tag, so it is parsed
		// once. Failures are reported against the first name.
		num := protobufFieldNumber(field.Names[0], field.Tag.Value)

		for _, name := range field.Names {
			result[name.Name] = num
		}
	}

	return result
}

// protobufFieldNumber extracts the number from the protobuf key of the struct
// tag literal belonging to the field called name, panicking on any failure.
func protobufFieldNumber(name *ast.Ident, literal string) int {
	var tagValue string

	if strings.HasPrefix(literal, "`") {
		// Raw string literals contain no escapes: dropping the backticks is enough.
		tagValue = strings.Trim(literal, "`")
	} else {
		// Interpreted ("...") literals must be unquoted so that escapes such
		// as \" are resolved before the tag is parsed.
		unquoted, err := strconv.Unquote(literal)
		if err != nil {
			panic(fmt.Errorf("invalid tag: field '%s', tag '%s': %w", name, literal, err))
		}

		tagValue = unquoted
	}

	tags, err := structtag.Parse(tagValue)
	if err != nil {
		panic(fmt.Errorf("invalid tag: field '%s', tag '%s': %w", name, tagValue, err))
	}

	tag, err := tags.Get("protobuf")
	if err != nil {
		panic(fmt.Errorf("cannot find protobuf tag: field '%s', tag '%s': %w", name, tagValue, err))
	}

	num, err := strconv.Atoi(tag.Name)
	if err != nil {
		panic(fmt.Errorf("invalid protobuf tag: field '%s', tag '%s': %w", name, tagValue, err))
	}

	return num
}