mirror of
				https://github.com/siderolabs/talos.git
				synced 2025-11-04 10:21:13 +01:00 
			
		
		
		
	`structprotogen` now supports generating enums directly instead of using a predeclared file and hardcoded types. To use this functionality, put `structprotogen:gen_enum` in the comment above the const block you want proto definitions generated for. Closes #6215 Signed-off-by: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com>
		
			
				
	
	
		
			493 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			493 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// This Source Code Form is subject to the terms of the Mozilla Public
 | 
						|
// License, v. 2.0. If a copy of the MPL was not distributed with this
 | 
						|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
 | 
						|
 | 
						|
// Package proto contains the protobuf generation logic.
 | 
						|
package proto
 | 
						|
 | 
						|
//nolint:gci
 | 
						|
import (
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"regexp"
 | 
						|
	"strings"
 | 
						|
 | 
						|
	"gopkg.in/typ.v4/slices"
 | 
						|
 | 
						|
	"github.com/siderolabs/structprotogen/consts"
 | 
						|
	"github.com/siderolabs/structprotogen/sliceutil"
 | 
						|
	"github.com/siderolabs/structprotogen/types"
 | 
						|
)
 | 
						|
 | 
						|
// Pkg represents a protobuf package.
 | 
						|
type Pkg struct {
 | 
						|
	Name  string
 | 
						|
	GoPkg string
 | 
						|
 | 
						|
	isInit    bool
 | 
						|
	protoDefs slices.Sorted[*protoDef]
 | 
						|
	imports   slices.Sorted[string]
 | 
						|
}
 | 
						|
 | 
						|
func protoPkgsCmp(left, right *Pkg) int {
 | 
						|
	return strings.Compare(left.Name, right.Name)
 | 
						|
}
 | 
						|
 | 
						|
func (p *Pkg) init() {
 | 
						|
	if !p.isInit {
 | 
						|
		p.protoDefs = slices.NewSortedCompare([]*protoDef{}, protoDefCmp)
 | 
						|
		p.imports = slices.NewSortedCompare([]string{}, strings.Compare)
 | 
						|
		p.isInit = true
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Defs returns the list of definitions.
 | 
						|
func (p *Pkg) Defs() *slices.Sorted[*protoDef] {
 | 
						|
	p.init()
 | 
						|
 | 
						|
	return &p.protoDefs
 | 
						|
}
 | 
						|
 | 
						|
// Imports returns the list of imports.
 | 
						|
func (p *Pkg) Imports() *slices.Sorted[string] {
 | 
						|
	p.init()
 | 
						|
 | 
						|
	return &p.imports
 | 
						|
}
 | 
						|
 | 
						|
// WriteDebug is like Format, but writes additional debug info.
 | 
						|
func (p *Pkg) WriteDebug(w io.Writer) {
 | 
						|
	pkgName := p.Name
 | 
						|
 | 
						|
	fmt.Fprint(w, "syntax = \"proto3\";\n\n")
 | 
						|
	fmt.Fprintf(w, "package talos.resource.definitions.%s; // %s\n\n", p.Name, p.GoPkg)
 | 
						|
	fmt.Fprintf(w, "option go_package = \"github.com/siderolabs/talos/pkg/machinery/api/resource/definitions/%s\";\n\n", pkgName) // TODO: insert proper path
 | 
						|
 | 
						|
	if p.imports.Len() > 0 {
 | 
						|
		for i := 0; i < p.imports.Len(); i++ {
 | 
						|
			importPath := p.imports.Get(i)
 | 
						|
			if !strings.ContainsRune(importPath, '.') {
 | 
						|
				importPath = "talos.resource.definitions." + importPath
 | 
						|
			}
 | 
						|
 | 
						|
			fmt.Fprintf(w, "import \"%s\";\n", importPath)
 | 
						|
		}
 | 
						|
 | 
						|
		fmt.Fprintln(w, ``)
 | 
						|
	}
 | 
						|
 | 
						|
	for i := 0; i < p.protoDefs.Len(); i++ {
 | 
						|
		p.protoDefs.Get(i).WriteDebug(w)
 | 
						|
		fmt.Fprintln(w)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Format formats the protobuf data.
 | 
						|
func (p *Pkg) Format(w io.Writer) {
 | 
						|
	pkgName := p.Name
 | 
						|
 | 
						|
	fmt.Fprint(w, "syntax = \"proto3\";\n\n")
 | 
						|
	fmt.Fprintf(w, "package talos.resource.definitions.%s;\n\n", p.Name)
 | 
						|
	fmt.Fprintf(w, "option go_package = \"github.com/siderolabs/talos/pkg/machinery/api/resource/definitions/%s\";\n\n", pkgName) // TODO: insert proper path
 | 
						|
 | 
						|
	if p.imports.Len() > 0 {
 | 
						|
		for i := 0; i < p.imports.Len(); i++ {
 | 
						|
			importPath := p.imports.Get(i)
 | 
						|
			if !strings.ContainsRune(importPath, '.') {
 | 
						|
				importPath = "talos.resource.definitions." + importPath
 | 
						|
			}
 | 
						|
 | 
						|
			fmt.Fprintf(w, "import \"%s\";\n", importPath)
 | 
						|
		}
 | 
						|
 | 
						|
		fmt.Fprintln(w, ``)
 | 
						|
	}
 | 
						|
 | 
						|
	for i := 0; i < p.protoDefs.Len(); i++ {
 | 
						|
		p.protoDefs.Get(i).Format(w)
 | 
						|
		fmt.Fprintln(w)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
type protoDef struct {
 | 
						|
	name string
 | 
						|
 | 
						|
	goPkg    string
 | 
						|
	comments []string
 | 
						|
 | 
						|
	isInit bool
 | 
						|
	fields slices.Sorted[protoField]
 | 
						|
}
 | 
						|
 | 
						|
func protoDefCmp(left, right *protoDef) int {
 | 
						|
	return strings.Compare(left.name, right.name)
 | 
						|
}
 | 
						|
 | 
						|
func (p *protoDef) init() {
 | 
						|
	if !p.isInit {
 | 
						|
		p.fields = slices.NewSortedCompare([]protoField{}, protoFieldCmp)
 | 
						|
		p.isInit = true
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (p *protoDef) Fields() *slices.Sorted[protoField] {
 | 
						|
	p.init()
 | 
						|
 | 
						|
	return &p.fields
 | 
						|
}
 | 
						|
 | 
						|
func (p *protoDef) WriteDebug(w io.Writer) {
 | 
						|
	for _, comment := range p.comments {
 | 
						|
		fmt.Fprintf(w, "%s\n", comment)
 | 
						|
	}
 | 
						|
 | 
						|
	fmt.Fprintf(w, "message %s { //%s.%s\n", p.name, p.goPkg, p.name)
 | 
						|
 | 
						|
	for i := 0; i < p.fields.Len(); i++ {
 | 
						|
		fmt.Fprintf(w, "  ")
 | 
						|
		p.fields.Get(i).WriteDebug(w)
 | 
						|
	}
 | 
						|
 | 
						|
	fmt.Fprintln(w, "}")
 | 
						|
}
 | 
						|
 | 
						|
func (p *protoDef) Format(w io.Writer) {
 | 
						|
	for _, comment := range p.comments {
 | 
						|
		fmt.Fprintf(w, "%s\n", comment)
 | 
						|
	}
 | 
						|
 | 
						|
	fmt.Fprintf(w, "message %s {\n", p.name)
 | 
						|
 | 
						|
	for i := 0; i < p.fields.Len(); i++ {
 | 
						|
		fmt.Fprintf(w, "  ")
 | 
						|
		p.fields.Get(i).Format(w)
 | 
						|
	}
 | 
						|
 | 
						|
	fmt.Fprintln(w, "}")
 | 
						|
}
 | 
						|
 | 
						|
// protoField is a single field of a generated protobuf message.
type protoField struct {
	// name is the Go field name; it is converted to snake_case on output.
	name string
	// typ is the full protobuf type expression, e.g. "string" or "repeated bytes".
	typ string
	// num is the protobuf field number; protoFieldCmp panics when it is zero.
	num int
	// goType is the source Go type, written only by WriteDebug.
	goType string
}
 | 
						|
 | 
						|
func protoFieldCmp(left, right protoField) int {
 | 
						|
	if left.num == 0 {
 | 
						|
		panic(fmt.Errorf("left field '%s' has no number", left.name))
 | 
						|
	}
 | 
						|
 | 
						|
	if right.num == 0 {
 | 
						|
		panic(fmt.Errorf("right field '%s' has no number", right.name))
 | 
						|
	}
 | 
						|
 | 
						|
	switch {
 | 
						|
	case left.num < right.num:
 | 
						|
		return -1
 | 
						|
	case left.num > right.num:
 | 
						|
		return 1
 | 
						|
	default:
 | 
						|
		return 0
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (pf protoField) WriteDebug(w io.Writer) {
 | 
						|
	fmt.Fprintf(w, "%s %s = %d; // %s \n", pf.typ, ToSnakeCase(pf.name), pf.num, pf.goType)
 | 
						|
}
 | 
						|
 | 
						|
func (pf protoField) Format(w io.Writer) {
 | 
						|
	fmt.Fprintf(w, "%s %s = %d;\n", pf.typ, ToSnakeCase(pf.name), pf.num)
 | 
						|
}
 | 
						|
 | 
						|
// PrepareProtoData prepares the data for the protobuf generation.
 | 
						|
//
 | 
						|
//nolint:gocyclo,cyclop
 | 
						|
func PrepareProtoData(pkgsTypes slices.Sorted[*types.Type], constants consts.ConstBlocks) slices.Sorted[*Pkg] {
 | 
						|
	result := slices.NewSortedCompare([]*Pkg{}, protoPkgsCmp)
 | 
						|
 | 
						|
	for i := 0; i < pkgsTypes.Len(); i++ {
 | 
						|
		pkgType := pkgsTypes.Get(i)
 | 
						|
 | 
						|
		protoPkg := sliceutil.GetOrAdd(&result, &Pkg{
 | 
						|
			Name:  pkgType.PkgName(),
 | 
						|
			GoPkg: pkgType.Pkg,
 | 
						|
		})
 | 
						|
 | 
						|
		def := sliceutil.GetOrAdd(protoPkg.Defs(), &protoDef{
 | 
						|
			name:     pkgType.Name,
 | 
						|
			goPkg:    pkgType.Pkg,
 | 
						|
			comments: pkgType.Comments,
 | 
						|
		})
 | 
						|
 | 
						|
		for j := 0; j < pkgType.Fields().Len(); j++ {
 | 
						|
			field := pkgType.Fields().Get(j)
 | 
						|
 | 
						|
			fieldTypeData := types.TypeInfo(field.TypeData.Type())
 | 
						|
 | 
						|
			if fieldTyp, ok := types.MatchTypeData[types.Complex](fieldTypeData); ok {
 | 
						|
				importName, typeName := mustFormatTypeName(fieldTyp.Pkg, fieldTyp.Name, pkgType.Pkg)
 | 
						|
 | 
						|
				if importName != "" {
 | 
						|
					sliceutil.AddIfNotFound(protoPkg.Imports(), importName)
 | 
						|
				}
 | 
						|
 | 
						|
				sliceutil.AddIfNotFound(def.Fields(), protoField{
 | 
						|
					name:   field.Name,
 | 
						|
					typ:    typeName,
 | 
						|
					num:    field.Num,
 | 
						|
					goType: field.TypeData.Type().String(),
 | 
						|
				})
 | 
						|
 | 
						|
				continue
 | 
						|
			}
 | 
						|
 | 
						|
			if fieldTyp, ok := types.MatchTypeData[types.Basic](fieldTypeData); ok {
 | 
						|
				var importName, typeName string
 | 
						|
 | 
						|
				if block, ok := constants.Get(fieldTyp.Pkg, fieldTyp.Name); ok {
 | 
						|
					importName = "resource/definitions/enums/enums.proto"
 | 
						|
					typeName = "talos.resource.definitions.enums." + block.ProtoMessageName()
 | 
						|
				} else {
 | 
						|
					importName, typeName = mustFormatBasicTypeName(fieldTyp.Pkg, fieldTyp.Name)
 | 
						|
				}
 | 
						|
 | 
						|
				if importName != "" {
 | 
						|
					sliceutil.AddIfNotFound(protoPkg.Imports(), importName)
 | 
						|
				}
 | 
						|
 | 
						|
				sliceutil.AddIfNotFound(def.Fields(), protoField{
 | 
						|
					name:   field.Name,
 | 
						|
					typ:    typeName,
 | 
						|
					num:    field.Num,
 | 
						|
					goType: field.TypeData.Type().String(),
 | 
						|
				})
 | 
						|
 | 
						|
				continue
 | 
						|
			}
 | 
						|
 | 
						|
			if fieldTyp, ok := types.MatchTypeData[types.Slice](fieldTypeData); ok {
 | 
						|
				var importName, typeName string
 | 
						|
 | 
						|
				switch {
 | 
						|
				case fieldTyp.Pkg == "" && fieldTyp.Name == "byte" && fieldTyp.Is2DSlice: //nolint:goconst
 | 
						|
					typeName = "repeated bytes"
 | 
						|
				case fieldTyp.Pkg == "" && fieldTyp.Name == "byte":
 | 
						|
					typeName = "bytes"
 | 
						|
				case fieldTyp.Pkg == "":
 | 
						|
					typeName = fmt.Sprintf("repeated %s", getProtoBasicName(fieldTyp.Name))
 | 
						|
				default:
 | 
						|
					importName, typeName = mustFormatTypeName(fieldTyp.Pkg, fieldTyp.Name, pkgType.Pkg)
 | 
						|
					typeName = fmt.Sprintf("repeated %s", typeName)
 | 
						|
				}
 | 
						|
 | 
						|
				if importName != "" {
 | 
						|
					sliceutil.AddIfNotFound(protoPkg.Imports(), importName)
 | 
						|
				}
 | 
						|
 | 
						|
				sliceutil.AddIfNotFound(def.Fields(), protoField{
 | 
						|
					name:   field.Name,
 | 
						|
					typ:    typeName,
 | 
						|
					num:    field.Num,
 | 
						|
					goType: field.TypeData.Type().String(),
 | 
						|
				})
 | 
						|
 | 
						|
				continue
 | 
						|
			}
 | 
						|
 | 
						|
			if fieldTyp, ok := types.MatchTypeData[types.Map](fieldTypeData); ok {
 | 
						|
				// key cannot be anything but a basic type
 | 
						|
				importKeyName, keyTypeName := mustFormatBasicTypeName(fieldTyp.KeyTypePkg, fieldTyp.KeyTypeName)
 | 
						|
				if importKeyName != "" {
 | 
						|
					panic(fmt.Errorf("map key type '%s.%s' is not basic type", fieldTyp.KeyTypePkg, fieldTyp.KeyTypeName))
 | 
						|
				}
 | 
						|
 | 
						|
				var (
 | 
						|
					typText    string
 | 
						|
					importElem string
 | 
						|
				)
 | 
						|
 | 
						|
				switch {
 | 
						|
				case fieldTyp.ElemTypeName == "interface{}": // handle map[key]interface{}
 | 
						|
					importElem = "google/protobuf/struct.proto"
 | 
						|
					typText = "google.protobuf.Struct"
 | 
						|
				case fieldTyp.ElemTypePkg == "":
 | 
						|
					elemTypeName := getProtoBasicName(fieldTyp.ElemTypeName)
 | 
						|
					typText = fmt.Sprintf("map<%s, %s>", keyTypeName, elemTypeName)
 | 
						|
				case fieldTyp.ElemTypePkg == pkgType.Pkg:
 | 
						|
					var elemTypeName string
 | 
						|
					importElem, elemTypeName = mustFormatTypeName(fieldTyp.ElemTypePkg, fieldTyp.ElemTypeName, pkgType.Pkg)
 | 
						|
					typText = fmt.Sprintf("map<%s, %s>", keyTypeName, elemTypeName)
 | 
						|
				default:
 | 
						|
					panic(fmt.Errorf("map value type '%s.%s' is not known type", fieldTyp.ElemTypePkg, fieldTyp.ElemTypeName))
 | 
						|
				}
 | 
						|
 | 
						|
				if importElem != "" {
 | 
						|
					sliceutil.AddIfNotFound(protoPkg.Imports(), importElem)
 | 
						|
				}
 | 
						|
 | 
						|
				sliceutil.AddIfNotFound(def.Fields(), protoField{
 | 
						|
					name:   field.Name,
 | 
						|
					typ:    typText,
 | 
						|
					num:    field.Num,
 | 
						|
					goType: field.TypeData.Type().String(),
 | 
						|
				})
 | 
						|
 | 
						|
				continue
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return result
 | 
						|
}
 | 
						|
 | 
						|
func mustFormatTypeName(fieldTypePkg string, fieldType string, declPkg string) (string, string) {
 | 
						|
	importPath, name := formatTypeName(fieldTypePkg, fieldType, declPkg)
 | 
						|
	if name == "" {
 | 
						|
		panic(fmt.Errorf("unknown type %s.%s", fieldType, fieldTypePkg))
 | 
						|
	}
 | 
						|
 | 
						|
	return importPath, name
 | 
						|
}
 | 
						|
 | 
						|
// formatTypeName maps a Go type, identified by package path and name, to the
// protobuf type used for it and the proto file that must be imported (empty
// when no import is needed).
//
// A type declared in declPkg maps to its own name with no import. Otherwise
// only a fixed set of well-known external types is supported; for any other
// type both results are empty.
func formatTypeName(fieldTypePkg string, fieldType string, declPkg string) (string, string) {
	if fieldTypePkg == declPkg {
		return "", fieldType
	}

	const (
		commonProto    = "common/common.proto"
		timestampProto = "google/protobuf/timestamp.proto"
		definitions    = "resource/definitions/proto/proto.proto"
	)

	type qualifiedName struct {
		pkg  string
		name string
	}

	key := qualifiedName{pkg: fieldTypePkg, name: fieldType}

	switch key {
	case qualifiedName{"time", "Time"}:
		return timestampProto, "google.protobuf.Timestamp"
	case qualifiedName{"net/url", "URL"}:
		return commonProto, "common.URL"
	case qualifiedName{"net/netip", "Prefix"}:
		return commonProto, "common.NetIPPrefix"
	case qualifiedName{"net/netip", "AddrPort"}:
		return commonProto, "common.NetIPPort"
	case qualifiedName{"net/netip", "Addr"}:
		return commonProto, "common.NetIP"
	case qualifiedName{"github.com/opencontainers/runtime-spec/specs-go", "Mount"}:
		return definitions, "talos.resource.definitions.proto.Mount"
	case qualifiedName{"github.com/siderolabs/crypto/x509", "PEMEncodedCertificateAndKey"}:
		return commonProto, "common.PEMEncodedCertificateAndKey"
	case qualifiedName{"github.com/siderolabs/crypto/x509", "PEMEncodedKey"}:
		return commonProto, "common.PEMEncodedKey"
	}

	return "", ""
}
 | 
						|
 | 
						|
func mustFormatBasicTypeName(fieldTypePkg string, fieldType string) (string, string) {
 | 
						|
	if fieldTypePkg == "" {
 | 
						|
		return "", getProtoBasicName(fieldType)
 | 
						|
	}
 | 
						|
 | 
						|
	importPath, fullName := formatBasicTypeName(fieldTypePkg, fieldType)
 | 
						|
	if fullName == "" {
 | 
						|
		panic(fmt.Errorf("unknown type %s.%s", fieldTypePkg, fieldType))
 | 
						|
	}
 | 
						|
 | 
						|
	return importPath, fullName
 | 
						|
}
 | 
						|
 | 
						|
// IsSupportedExternalType checks if external type is supported.
 | 
						|
func IsSupportedExternalType(typ types.ExternalType) bool {
 | 
						|
	if _, name := formatBasicTypeName(typ.Pkg, typ.Name); name != "" {
 | 
						|
		return true
 | 
						|
	}
 | 
						|
 | 
						|
	if _, name := formatTypeName(typ.Pkg, typ.Name, ""); name != "" {
 | 
						|
		return true
 | 
						|
	}
 | 
						|
 | 
						|
	return false
 | 
						|
}
 | 
						|
 | 
						|
// formatBasicTypeName maps a named basic Go type from another package to its
// protobuf representation: time.Duration becomes google.protobuf.Duration,
// and fs.FileMode plus the nethelpers flag types become uint32. For any other
// type both results are empty.
//
//nolint:gocyclo,cyclop
func formatBasicTypeName(typPkg string, typ string) (importPath, fullName string) {
	const (
		nethelpersPkg = "github.com/siderolabs/talos/pkg/machinery/nethelpers"
		protoUint32   = "uint32"
	)

	switch typPkg {
	case "time":
		if typ == "Duration" {
			return "google/protobuf/duration.proto", "google.protobuf.Duration"
		}
	case "io/fs":
		if typ == "FileMode" {
			return "", protoUint32
		}
	case nethelpersPkg:
		switch typ {
		case "AddressFlags", "LinkFlags", "RouteFlags":
			return "", protoUint32
		}
	}

	return "", ""
}
 | 
						|
 | 
						|
// getProtoBasicName returns the protobuf scalar type for a built-in Go type.
// bool, int32, uint32 and string keep their names. Types narrower than 32 bits
// (int8, int16, byte, uint8, uint16) map to fixed32. int and uint widen to
// their 64-bit forms, and float32/float64 become float/double. Any other type
// name panics.
//
//nolint:gocyclo
func getProtoBasicName(typ string) string {
	switch typ {
	case "bool", "int32", "uint32", "string":
		return typ
	case "int8", "int16", "byte", "uint8", "uint16":
		return "fixed32"
	case "int", "int64":
		return "int64"
	case "uint", "uint64":
		return "uint64"
	case "float32":
		return "float"
	case "float64":
		return "double"
	}

	panic(fmt.Sprintf("unknown type %s", typ))
}
 | 
						|
 | 
						|
// ToSnakeCase converts a string to snake case.
 | 
						|
func ToSnakeCase(str string) string {
 | 
						|
	snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}")
 | 
						|
	snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}")
 | 
						|
	snake = strings.ToLower(snake)
 | 
						|
 | 
						|
	// special case for "SomethingsIps"
 | 
						|
	if strings.HasSuffix(snake, "_i_ps") {
 | 
						|
		snake = strings.TrimSuffix(snake, "_i_ps") + "_ips"
 | 
						|
	}
 | 
						|
 | 
						|
	return snake
 | 
						|
}
 | 
						|
 | 
						|
// matchFirstCap matches any character followed by a capitalized word.
// ToSnakeCase inserts an underscore between the two groups.
var matchFirstCap = regexp.MustCompile(`(.)([A-Z][a-z]+)`)

// matchAllCap matches a lowercase letter or digit followed by an uppercase
// letter. ToSnakeCase inserts an underscore between the two groups.
var matchAllCap = regexp.MustCompile(`([a-z0-9])([A-Z])`)
 |