mirror of
https://github.com/siderolabs/talos.git
synced 2025-08-16 03:27:12 +02:00
There's a cyclic dependency on siderolink library which imports talos machinery back. We will fix that after we get talos pushed under a new name. Signed-off-by: Andrey Smirnov <andrey.smirnov@talos-systems.com>
533 lines
16 KiB
Go
533 lines
16 KiB
Go
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
// Package proto contains the protobuf generation logic.
|
|
package proto
|
|
|
|
//nolint:gci
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"gopkg.in/typ.v4/slices"
|
|
|
|
"github.com/siderolabs/structprotogen/sliceutil"
|
|
"github.com/siderolabs/structprotogen/types"
|
|
)
|
|
|
|
// Pkg represents a protobuf package.
|
|
type Pkg struct {
|
|
Name string
|
|
GoPkg string
|
|
|
|
isInit bool
|
|
protoDefs slices.Sorted[*protoDef]
|
|
imports slices.Sorted[string]
|
|
}
|
|
|
|
func protoPkgsCmp(left, right *Pkg) int {
|
|
return strings.Compare(left.Name, right.Name)
|
|
}
|
|
|
|
func (p *Pkg) init() {
|
|
if !p.isInit {
|
|
p.protoDefs = slices.NewSortedCompare([]*protoDef{}, protoDefCmp)
|
|
p.imports = slices.NewSortedCompare([]string{}, strings.Compare)
|
|
p.isInit = true
|
|
}
|
|
}
|
|
|
|
// Defs returns the list of definitions.
|
|
func (p *Pkg) Defs() *slices.Sorted[*protoDef] {
|
|
p.init()
|
|
|
|
return &p.protoDefs
|
|
}
|
|
|
|
// Imports returns the list of imports.
|
|
func (p *Pkg) Imports() *slices.Sorted[string] {
|
|
p.init()
|
|
|
|
return &p.imports
|
|
}
|
|
|
|
// WriteDebug is like Format, but writes additional debug info.
|
|
func (p *Pkg) WriteDebug(w io.Writer) {
|
|
pkgName := p.Name
|
|
|
|
fmt.Fprint(w, "syntax = \"proto3\";\n\n")
|
|
fmt.Fprintf(w, "package talos.resource.definitions.%s; // %s\n\n", p.Name, p.GoPkg)
|
|
fmt.Fprintf(w, "option go_package = \"github.com/siderolabs/talos/pkg/machinery/api/resource/definitions/%s\";\n\n", pkgName) // TODO: insert proper path
|
|
|
|
if p.imports.Len() > 0 {
|
|
for i := 0; i < p.imports.Len(); i++ {
|
|
importPath := p.imports.Get(i)
|
|
if !strings.ContainsRune(importPath, '.') {
|
|
importPath = "talos.resource.definitions." + importPath
|
|
}
|
|
|
|
fmt.Fprintf(w, "import \"%s\";\n", importPath)
|
|
}
|
|
|
|
fmt.Fprintln(w, ``)
|
|
}
|
|
|
|
for i := 0; i < p.protoDefs.Len(); i++ {
|
|
p.protoDefs.Get(i).WriteDebug(w)
|
|
fmt.Fprintln(w)
|
|
}
|
|
}
|
|
|
|
// Format formats the protobuf data.
|
|
func (p *Pkg) Format(w io.Writer) {
|
|
pkgName := p.Name
|
|
|
|
fmt.Fprint(w, "syntax = \"proto3\";\n\n")
|
|
fmt.Fprintf(w, "package talos.resource.definitions.%s;\n\n", p.Name)
|
|
fmt.Fprintf(w, "option go_package = \"github.com/siderolabs/talos/pkg/machinery/api/resource/definitions/%s\";\n\n", pkgName) // TODO: insert proper path
|
|
|
|
if p.imports.Len() > 0 {
|
|
for i := 0; i < p.imports.Len(); i++ {
|
|
importPath := p.imports.Get(i)
|
|
if !strings.ContainsRune(importPath, '.') {
|
|
importPath = "talos.resource.definitions." + importPath
|
|
}
|
|
|
|
fmt.Fprintf(w, "import \"%s\";\n", importPath)
|
|
}
|
|
|
|
fmt.Fprintln(w, ``)
|
|
}
|
|
|
|
for i := 0; i < p.protoDefs.Len(); i++ {
|
|
p.protoDefs.Get(i).Format(w)
|
|
fmt.Fprintln(w)
|
|
}
|
|
}
|
|
|
|
type protoDef struct {
|
|
name string
|
|
|
|
goPkg string
|
|
comments []string
|
|
|
|
isInit bool
|
|
fields slices.Sorted[protoField]
|
|
}
|
|
|
|
func protoDefCmp(left, right *protoDef) int {
|
|
return strings.Compare(left.name, right.name)
|
|
}
|
|
|
|
func (p *protoDef) init() {
|
|
if !p.isInit {
|
|
p.fields = slices.NewSortedCompare([]protoField{}, protoFieldCmp)
|
|
p.isInit = true
|
|
}
|
|
}
|
|
|
|
func (p *protoDef) Fields() *slices.Sorted[protoField] {
|
|
p.init()
|
|
|
|
return &p.fields
|
|
}
|
|
|
|
func (p *protoDef) WriteDebug(w io.Writer) {
|
|
for _, comment := range p.comments {
|
|
fmt.Fprintf(w, "%s\n", comment)
|
|
}
|
|
|
|
fmt.Fprintf(w, "message %s { //%s.%s\n", p.name, p.goPkg, p.name)
|
|
|
|
for i := 0; i < p.fields.Len(); i++ {
|
|
fmt.Fprintf(w, " ")
|
|
p.fields.Get(i).WriteDebug(w)
|
|
}
|
|
|
|
fmt.Fprintln(w, "}")
|
|
}
|
|
|
|
func (p *protoDef) Format(w io.Writer) {
|
|
for _, comment := range p.comments {
|
|
fmt.Fprintf(w, "%s\n", comment)
|
|
}
|
|
|
|
fmt.Fprintf(w, "message %s {\n", p.name)
|
|
|
|
for i := 0; i < p.fields.Len(); i++ {
|
|
fmt.Fprintf(w, " ")
|
|
p.fields.Get(i).Format(w)
|
|
}
|
|
|
|
fmt.Fprintln(w, "}")
|
|
}
|
|
|
|
type protoField struct {
|
|
name string
|
|
typ string
|
|
num int
|
|
|
|
goType string
|
|
}
|
|
|
|
func protoFieldCmp(left, right protoField) int {
|
|
if left.num == 0 {
|
|
panic(fmt.Errorf("left field '%s' has no number", left.name))
|
|
}
|
|
|
|
if right.num == 0 {
|
|
panic(fmt.Errorf("right field '%s' has no number", right.name))
|
|
}
|
|
|
|
switch {
|
|
case left.num < right.num:
|
|
return -1
|
|
case left.num > right.num:
|
|
return 1
|
|
default:
|
|
return 0
|
|
}
|
|
}
|
|
|
|
func (pf protoField) WriteDebug(w io.Writer) {
|
|
fmt.Fprintf(w, "%s %s = %d; // %s \n", pf.typ, ToSnakeCase(pf.name), pf.num, pf.goType)
|
|
}
|
|
|
|
func (pf protoField) Format(w io.Writer) {
|
|
fmt.Fprintf(w, "%s %s = %d;\n", pf.typ, ToSnakeCase(pf.name), pf.num)
|
|
}
|
|
|
|
// PrepareProtoData prepares the data for the protobuf generation.
|
|
//
|
|
//nolint:gocyclo,cyclop
|
|
func PrepareProtoData(pkgsTypes slices.Sorted[*types.Type]) slices.Sorted[*Pkg] {
|
|
result := slices.NewSortedCompare([]*Pkg{}, protoPkgsCmp)
|
|
|
|
for i := 0; i < pkgsTypes.Len(); i++ {
|
|
pkgType := pkgsTypes.Get(i)
|
|
|
|
protoPkg := sliceutil.GetOrAdd(&result, &Pkg{
|
|
Name: pkgType.PkgName(),
|
|
GoPkg: pkgType.Pkg,
|
|
})
|
|
|
|
def := sliceutil.GetOrAdd(protoPkg.Defs(), &protoDef{
|
|
name: pkgType.Name,
|
|
goPkg: pkgType.Pkg,
|
|
comments: pkgType.Comments,
|
|
})
|
|
|
|
for j := 0; j < pkgType.Fields().Len(); j++ {
|
|
field := pkgType.Fields().Get(j)
|
|
|
|
fieldTypeData := types.TypeInfo(field.TypeData.Type())
|
|
|
|
if fieldTyp, ok := types.MatchTypeData[types.Complex](fieldTypeData); ok {
|
|
importName, typeName := mustFormatTypeName(fieldTyp.Pkg, fieldTyp.Name, pkgType.Pkg)
|
|
|
|
if importName != "" {
|
|
sliceutil.AddIfNotFound(protoPkg.Imports(), importName)
|
|
}
|
|
|
|
sliceutil.AddIfNotFound(def.Fields(), protoField{
|
|
name: field.Name,
|
|
typ: typeName,
|
|
num: field.Num,
|
|
goType: field.TypeData.Type().String(),
|
|
})
|
|
|
|
continue
|
|
}
|
|
|
|
if fieldTyp, ok := types.MatchTypeData[types.Basic](fieldTypeData); ok {
|
|
importName, typeName := mustFormatBasicTypeName(fieldTyp.Pkg, fieldTyp.Name)
|
|
|
|
if importName != "" {
|
|
sliceutil.AddIfNotFound(protoPkg.Imports(), importName)
|
|
}
|
|
|
|
sliceutil.AddIfNotFound(def.Fields(), protoField{
|
|
name: field.Name,
|
|
typ: typeName,
|
|
num: field.Num,
|
|
goType: field.TypeData.Type().String(),
|
|
})
|
|
|
|
continue
|
|
}
|
|
|
|
if fieldTyp, ok := types.MatchTypeData[types.Slice](fieldTypeData); ok {
|
|
var importName, typeName string
|
|
|
|
switch {
|
|
case fieldTyp.Pkg == "" && fieldTyp.Name == "byte" && fieldTyp.Is2DSlice: //nolint:goconst
|
|
typeName = "repeated bytes"
|
|
case fieldTyp.Pkg == "" && fieldTyp.Name == "byte":
|
|
typeName = "bytes"
|
|
case fieldTyp.Pkg == "":
|
|
typeName = fmt.Sprintf("repeated %s", getProtoBasicName(fieldTyp.Name))
|
|
default:
|
|
importName, typeName = mustFormatTypeName(fieldTyp.Pkg, fieldTyp.Name, pkgType.Pkg)
|
|
typeName = fmt.Sprintf("repeated %s", typeName)
|
|
}
|
|
|
|
if importName != "" {
|
|
sliceutil.AddIfNotFound(protoPkg.Imports(), importName)
|
|
}
|
|
|
|
sliceutil.AddIfNotFound(def.Fields(), protoField{
|
|
name: field.Name,
|
|
typ: typeName,
|
|
num: field.Num,
|
|
goType: field.TypeData.Type().String(),
|
|
})
|
|
|
|
continue
|
|
}
|
|
|
|
if fieldTyp, ok := types.MatchTypeData[types.Map](fieldTypeData); ok {
|
|
// key cannot be anything but a basic type
|
|
importKeyName, keyTypeName := mustFormatBasicTypeName(fieldTyp.KeyTypePkg, fieldTyp.KeyTypeName)
|
|
if importKeyName != "" {
|
|
panic(fmt.Errorf("map key type '%s.%s' is not basic type", fieldTyp.KeyTypePkg, fieldTyp.KeyTypeName))
|
|
}
|
|
|
|
var (
|
|
typText string
|
|
importElem string
|
|
)
|
|
|
|
switch {
|
|
case fieldTyp.ElemTypeName == "interface{}": // handle map[key]interface{}
|
|
importElem = "google/protobuf/struct.proto"
|
|
typText = "google.protobuf.Struct"
|
|
case fieldTyp.ElemTypePkg == "":
|
|
elemTypeName := getProtoBasicName(fieldTyp.ElemTypeName)
|
|
typText = fmt.Sprintf("map<%s, %s>", keyTypeName, elemTypeName)
|
|
case fieldTyp.ElemTypePkg == pkgType.Pkg:
|
|
var elemTypeName string
|
|
importElem, elemTypeName = mustFormatTypeName(fieldTyp.ElemTypePkg, fieldTyp.ElemTypeName, pkgType.Pkg)
|
|
typText = fmt.Sprintf("map<%s, %s>", keyTypeName, elemTypeName)
|
|
default:
|
|
panic(fmt.Errorf("map value type '%s.%s' is not known type", fieldTyp.ElemTypePkg, fieldTyp.ElemTypeName))
|
|
}
|
|
|
|
if importElem != "" {
|
|
sliceutil.AddIfNotFound(protoPkg.Imports(), importElem)
|
|
}
|
|
|
|
sliceutil.AddIfNotFound(def.Fields(), protoField{
|
|
name: field.Name,
|
|
typ: typText,
|
|
num: field.Num,
|
|
goType: field.TypeData.Type().String(),
|
|
})
|
|
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
func mustFormatTypeName(fieldTypePkg string, fieldType string, declPkg string) (string, string) {
|
|
importPath, name := formatTypeName(fieldTypePkg, fieldType, declPkg)
|
|
if name == "" {
|
|
panic(fmt.Errorf("unknown type %s.%s", fieldType, fieldTypePkg))
|
|
}
|
|
|
|
return importPath, name
|
|
}
|
|
|
|
func formatTypeName(fieldTypePkg string, fieldType string, declPkg string) (string, string) {
|
|
if fieldTypePkg == declPkg {
|
|
return "", fieldType
|
|
}
|
|
|
|
type typeData struct {
|
|
pkg string
|
|
name string
|
|
}
|
|
|
|
td := typeData{
|
|
name: fieldType,
|
|
pkg: fieldTypePkg,
|
|
}
|
|
|
|
const commoProto = "common/common.proto"
|
|
|
|
switch td {
|
|
case typeData{"time", "Time"}:
|
|
return "google/protobuf/timestamp.proto", "google.protobuf.Timestamp"
|
|
case typeData{"net/url", "URL"}:
|
|
return commoProto, "common.URL"
|
|
case typeData{"net/netip", "Prefix"}:
|
|
return commoProto, "common.NetIPPrefix"
|
|
case typeData{"net/netip", "AddrPort"}:
|
|
return commoProto, "common.NetIPPort"
|
|
case typeData{"net/netip", "Addr"}:
|
|
return commoProto, "common.NetIP"
|
|
case typeData{"github.com/opencontainers/runtime-spec/specs-go", "Mount"}:
|
|
return "resource/definitions/proto/proto.proto", "talos.resource.definitions.proto.Mount"
|
|
case typeData{"github.com/siderolabs/crypto/x509", "PEMEncodedCertificateAndKey"}:
|
|
return commoProto, "common.PEMEncodedCertificateAndKey"
|
|
case typeData{"github.com/siderolabs/crypto/x509", "PEMEncodedKey"}:
|
|
return commoProto, "common.PEMEncodedKey"
|
|
default:
|
|
return "", ""
|
|
}
|
|
}
|
|
|
|
func mustFormatBasicTypeName(fieldTypePkg string, fieldType string) (string, string) {
|
|
if fieldTypePkg == "" {
|
|
return "", getProtoBasicName(fieldType)
|
|
}
|
|
|
|
importPath, fullName := formatBasicTypeName(fieldTypePkg, fieldType)
|
|
if fullName == "" {
|
|
panic(fmt.Errorf("unknown type %s.%s", fieldTypePkg, fieldType))
|
|
}
|
|
|
|
return importPath, fullName
|
|
}
|
|
|
|
// IsSupportedExternalType checks if external type is supported.
|
|
func IsSupportedExternalType(typ types.ExternalType) bool {
|
|
if _, name := formatBasicTypeName(typ.Pkg, typ.Name); name != "" {
|
|
return true
|
|
}
|
|
|
|
if _, name := formatTypeName(typ.Pkg, typ.Name, ""); name != "" {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
//nolint:gocyclo,cyclop
|
|
func formatBasicTypeName(typPkg string, typ string) (importPath, fullName string) {
|
|
type typeData struct {
|
|
pkg string
|
|
name string
|
|
}
|
|
|
|
td := typeData{
|
|
name: typ,
|
|
pkg: typPkg,
|
|
}
|
|
|
|
const enumsProto = "resource/definitions/enums/enums.proto"
|
|
|
|
switch td {
|
|
case typeData{"time", "Duration"}:
|
|
return "google/protobuf/duration.proto", "google.protobuf.Duration"
|
|
case typeData{"io/fs", "FileMode"}:
|
|
return "", "uint32" //nolint:goconst
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/config/types/v1alpha1/machine", "Type"}:
|
|
return enumsProto, "talos.resource.definitions.enums.MachineType"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/resources/kubespan", "PeerState"}:
|
|
return enumsProto, "talos.resource.definitions.enums.KubespanPeerState"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/resources/network", "ConfigLayer"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NetworkConfigLayer"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/resources/network", "Operator"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NetworkOperator"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "Family"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersFamily"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "AddressFlags"}:
|
|
return "", "uint32"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "Scope"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersScope"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "ADSelect"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersADSelect"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "ARPAllTargets"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersARPAllTargets"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "ARPValidate"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersARPValidate"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "FailOverMAC"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersFailOverMAC"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "BondXmitHashPolicy"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersBondXmitHashPolicy"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "LACPRate"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersLACPRate"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "BondMode"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersBondMode"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "PrimaryReselect"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersPrimaryReselect"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "LinkType"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersLinkType"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "Duplex"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersDuplex"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "LinkFlags"}:
|
|
return "", "uint32"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "OperationalState"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersOperationalState"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "Port"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersPort"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "RouteFlags"}:
|
|
return "", "uint32"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "RouteProtocol"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersRouteProtocol"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "RoutingTable"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersRoutingTable"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "RouteType"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersRouteType"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/nethelpers", "VLANProtocol"}:
|
|
return enumsProto, "talos.resource.definitions.enums.NethelpersVLANProtocol"
|
|
case typeData{"github.com/siderolabs/talos/pkg/machinery/resources/runtime", "MachineStage"}:
|
|
return enumsProto, "talos.resource.definitions.enums.RuntimeMachineStage"
|
|
default:
|
|
return "", ""
|
|
}
|
|
}
|
|
|
|
//nolint:gocyclo
|
|
func getProtoBasicName(typ string) string {
|
|
switch typ {
|
|
case "bool":
|
|
return "bool"
|
|
case "int8", "int16":
|
|
return "fixed32"
|
|
case "int32":
|
|
return "int32"
|
|
case "int64", "int":
|
|
return "int64"
|
|
case "byte", "uint8", "uint16":
|
|
return "fixed32"
|
|
case "uint32":
|
|
return "uint32"
|
|
case "uint64", "uint":
|
|
return "uint64"
|
|
case "float32":
|
|
return "float"
|
|
case "float64":
|
|
return "double"
|
|
case "string":
|
|
return "string"
|
|
default:
|
|
panic(fmt.Sprintf("unknown type %s", typ))
|
|
}
|
|
}
|
|
|
|
// ToSnakeCase converts a string to snake case.
|
|
func ToSnakeCase(str string) string {
|
|
snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}")
|
|
snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}")
|
|
snake = strings.ToLower(snake)
|
|
|
|
// special case for "SomethingsIps"
|
|
if strings.HasSuffix(snake, "_i_ps") {
|
|
snake = strings.TrimSuffix(snake, "_i_ps") + "_ips"
|
|
}
|
|
|
|
return snake
|
|
}
|
|
|
|
var (
|
|
matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
|
|
matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")
|
|
)
|