mirror of
				https://github.com/siderolabs/talos.git
				synced 2025-11-04 10:21:13 +01:00 
			
		
		
		
	This commit adds structprotogen tool which is used to generate proto file from Go structs. Closes #6078. Signed-off-by: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com>
		
			
				
	
	
		
			251 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			251 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// This Source Code Form is subject to the terms of the Mozilla Public
 | 
						|
// License, v. 2.0. If a copy of the MPL was not distributed with this
 | 
						|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
 | 
						|
 | 
						|
// Package ast is used to find all structs with expected tag in the given AST.
 | 
						|
package ast
 | 
						|
 | 
						|
import (
 | 
						|
	"fmt"
 | 
						|
	"go/ast"
 | 
						|
	"go/token"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
	"unicode"
 | 
						|
 | 
						|
	"github.com/fatih/structtag"
 | 
						|
	"golang.org/x/tools/go/packages"
 | 
						|
)
 | 
						|
 | 
						|
// FindAllTaggedStructs extracts all structs with comments which contain tag in the given AST.
 | 
						|
func FindAllTaggedStructs(pkgs []*packages.Package) TaggedStructs {
 | 
						|
	result := TaggedStructs{}
 | 
						|
 | 
						|
	for _, pkg := range pkgs {
 | 
						|
		for _, file := range pkg.Syntax {
 | 
						|
			structs := findAllStructs(file)
 | 
						|
 | 
						|
			for _, structDef := range structs {
 | 
						|
				result.Add(pkg.PkgPath, structDef.Struct.Name.Name, TaggedStruct{
 | 
						|
					Comments: formatComments(structDef.Comment),
 | 
						|
					Fields:   structDef.Fields,
 | 
						|
				})
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return result
 | 
						|
}
 | 
						|
 | 
						|
// TaggedStructs is a map of tagged struct declarations to their data.
 | 
						|
type TaggedStructs map[StructDecl]TaggedStruct
 | 
						|
 | 
						|
// TaggedStruct contains struct comments and fields.
 | 
						|
type TaggedStruct struct {
 | 
						|
	Comments []string
 | 
						|
	Fields   Fields
 | 
						|
}
 | 
						|
 | 
						|
// Get returns the struct data for given pkg and pkgPath.
 | 
						|
func (t TaggedStructs) Get(pkg, name string) (TaggedStruct, bool) {
 | 
						|
	val, ok := t[StructDecl{pkg, name}]
 | 
						|
 | 
						|
	return val, ok
 | 
						|
}
 | 
						|
 | 
						|
// Add adds the given struct data to the map.
 | 
						|
func (t TaggedStructs) Add(pkg, name string, structData TaggedStruct) {
 | 
						|
	t[StructDecl{Pkg: pkg, Name: name}] = structData
 | 
						|
}
 | 
						|
 | 
						|
// findAllStructs finds all structs with comments in the given AST.
 | 
						|
func findAllStructs(f ast.Node) []structData {
 | 
						|
	var result []structData
 | 
						|
 | 
						|
	ast.Inspect(f, func(n ast.Node) bool {
 | 
						|
		genDecl, ok := n.(*ast.GenDecl)
 | 
						|
		if !ok || genDecl.Tok != token.TYPE {
 | 
						|
			return true
 | 
						|
		}
 | 
						|
 | 
						|
		typeSpecs := filter(genDecl.Specs, typeSpecFilter)
 | 
						|
 | 
						|
		switch len(typeSpecs) {
 | 
						|
		case 0:
 | 
						|
			return false
 | 
						|
		case 1:
 | 
						|
			switch {
 | 
						|
			case isTarget(genDecl.Doc):
 | 
						|
				fields := getStructFieldsWithTags(typeSpecs[0])
 | 
						|
 | 
						|
				result = append(result, structData{
 | 
						|
					Struct:  typeSpecs[0],
 | 
						|
					Comment: genDecl.Doc,
 | 
						|
					Fields:  fields,
 | 
						|
				})
 | 
						|
			case isTarget(typeSpecs[0].Doc):
 | 
						|
				fields := getStructFieldsWithTags(typeSpecs[0])
 | 
						|
 | 
						|
				result = append(result, structData{
 | 
						|
					Struct:  typeSpecs[0],
 | 
						|
					Comment: typeSpecs[0].Doc,
 | 
						|
					Fields:  fields,
 | 
						|
				})
 | 
						|
			}
 | 
						|
 | 
						|
			return false
 | 
						|
		default:
 | 
						|
			for _, typeSpec := range typeSpecs {
 | 
						|
				if isTarget(typeSpec.Doc) {
 | 
						|
					fields := getStructFieldsWithTags(typeSpec)
 | 
						|
 | 
						|
					result = append(result, structData{
 | 
						|
						Struct:  typeSpec,
 | 
						|
						Comment: typeSpec.Doc,
 | 
						|
						Fields:  fields,
 | 
						|
					})
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			return false
 | 
						|
		}
 | 
						|
	})
 | 
						|
 | 
						|
	return result
 | 
						|
}
 | 
						|
 | 
						|
type structData struct {
 | 
						|
	Struct  *ast.TypeSpec
 | 
						|
	Comment *ast.CommentGroup
 | 
						|
	Fields  Fields
 | 
						|
}
 | 
						|
 | 
						|
func typeSpecFilter(n ast.Spec) (*ast.TypeSpec, bool) {
 | 
						|
	typeSpec, ok := n.(*ast.TypeSpec)
 | 
						|
	if !ok {
 | 
						|
		return nil, false
 | 
						|
	}
 | 
						|
 | 
						|
	structType, ok := typeSpec.Type.(*ast.StructType)
 | 
						|
	if !ok {
 | 
						|
		return nil, false
 | 
						|
	}
 | 
						|
 | 
						|
	if structType.Fields.NumFields() == 0 {
 | 
						|
		return nil, false
 | 
						|
	}
 | 
						|
 | 
						|
	if !isCapitalCase(typeSpec.Name.Name) {
 | 
						|
		return nil, false
 | 
						|
	}
 | 
						|
 | 
						|
	return typeSpec, true
 | 
						|
}
 | 
						|
 | 
						|
// isCapitalCase reports whether s starts with an upper-case letter.
//
// The check is made on the first rune rather than the first byte, so names that
// begin with a multi-byte character are classified correctly. An empty string
// yields false.
func isCapitalCase(s string) bool {
	// IndexFunc returns the byte offset of the first upper-case rune, or -1 if
	// there is none; offset 0 means the very first rune is upper-case.
	return strings.IndexFunc(s, unicode.IsUpper) == 0
}
 | 
						|
 | 
						|
func isTarget(doc *ast.CommentGroup) bool {
 | 
						|
	if doc == nil {
 | 
						|
		return false
 | 
						|
	}
 | 
						|
 | 
						|
	for _, c := range doc.List {
 | 
						|
		if isTargetComment(c.Text) {
 | 
						|
			return true
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return false
 | 
						|
}
 | 
						|
 | 
						|
// isTargetComment reports whether str is exactly the marker comment that selects
// a struct for processing. The comparison is exact: no whitespace is tolerated.
func isTargetComment(str string) bool {
	const marker = "//gotagsrewrite:gen"

	return str == marker
}
 | 
						|
 | 
						|
// filter applies f to every element of slc in order and returns the values for
// which f reported true. The result is nil when no element is accepted.
func filter[T, V any](slc []T, f func(n T) (V, bool)) []V {
	var kept []V

	for i := range slc {
		mapped, keep := f(slc[i])
		if !keep {
			continue
		}

		kept = append(kept, mapped)
	}

	return kept
}
 | 
						|
 | 
						|
// StructDecl identifies a struct declaration and is used as the key of TaggedStructs.
type StructDecl struct {
	// Pkg is the package the struct belongs to; Name is the struct's type name.
	Pkg, Name string
}
 | 
						|
 | 
						|
func formatComments(comment *ast.CommentGroup) []string {
 | 
						|
	if comment == nil {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
 | 
						|
	result := make([]string, 0, len(comment.List))
 | 
						|
 | 
						|
	for i := range comment.List {
 | 
						|
		if len(comment.List)-1 >= i+1 && isTargetComment(comment.List[i+1].Text) {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		if isTargetComment(comment.List[i].Text) {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		result = append(result, comment.List[i].Text)
 | 
						|
	}
 | 
						|
 | 
						|
	return result
 | 
						|
}
 | 
						|
 | 
						|
// Fields maps struct field names to their protobuf field numbers.
type Fields map[string]int
 | 
						|
 | 
						|
// getStructFieldsWithTags returns all fields of the given struct with their tags.
 | 
						|
func getStructFieldsWithTags(structDecl *ast.TypeSpec) Fields {
 | 
						|
	result := Fields{}
 | 
						|
 | 
						|
	structType := structDecl.Type.(*ast.StructType) //nolint:errcheck
 | 
						|
 | 
						|
	for _, field := range structType.Fields.List {
 | 
						|
		if field.Names == nil {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		for _, name := range field.Names {
 | 
						|
			if field.Tag == nil {
 | 
						|
				continue
 | 
						|
			}
 | 
						|
 | 
						|
			tagValue := strings.Trim(field.Tag.Value, "`")
 | 
						|
 | 
						|
			tags, err := structtag.Parse(tagValue)
 | 
						|
			if err != nil {
 | 
						|
				panic(fmt.Errorf("invalid tag: field '%s', tag '%s': %w", name, tagValue, err))
 | 
						|
			}
 | 
						|
 | 
						|
			tag, err := tags.Get("protobuf")
 | 
						|
			if err != nil {
 | 
						|
				panic(fmt.Errorf("cannot find protobuf tag: field '%s', tag '%s': %w", name, tagValue, err))
 | 
						|
			}
 | 
						|
 | 
						|
			num, err := strconv.Atoi(tag.Name)
 | 
						|
			if err != nil {
 | 
						|
				panic(fmt.Errorf("invalid protobuf tag: field '%s', tag '%s': %w", name, tagValue, err))
 | 
						|
			}
 | 
						|
 | 
						|
			result[name.Name] = num
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return result
 | 
						|
}
 |