talos/hack/structprotogen/proto/proto.go
Andrey Smirnov f62d17125b
chore: update crypto to use new import path siderolabs/crypto
No functional changes in this PR, just updating import paths.

Signed-off-by: Andrey Smirnov <andrey.smirnov@talos-systems.com>
2022-09-07 23:02:50 +04:00

533 lines
16 KiB
Go

// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
// Package proto contains the protobuf generation logic.
package proto
//nolint:gci
import (
"fmt"
"io"
"regexp"
"strings"
"gopkg.in/typ.v4/slices"
"github.com/talos-systems/structprotogen/sliceutil"
"github.com/talos-systems/structprotogen/types"
)
// Pkg represents a protobuf package.
//
// Its definition and import lists are created lazily by init, which both Defs
// and Imports call, so a Pkg built as a plain struct literal is usable.
type Pkg struct {
	// Name is the last element of the proto package name
	// (talos.resource.definitions.<Name>); GoPkg is the Go package path the
	// definitions were generated from.
	Name, GoPkg string

	// isInit records whether protoDefs and imports have been set up by init.
	isInit    bool
	protoDefs slices.Sorted[*protoDef]
	imports   slices.Sorted[string]
}
// protoPkgsCmp orders packages lexicographically by their Name.
func protoPkgsCmp(left, right *Pkg) int {
	l, r := left.Name, right.Name

	return strings.Compare(l, r)
}
// init creates the sorted definition list (ordered by protoDefCmp) and the
// sorted import list (ordered by strings.Compare) on first use; later calls
// do nothing.
func (p *Pkg) init() {
	if p.isInit {
		return
	}

	p.protoDefs = slices.NewSortedCompare([]*protoDef{}, protoDefCmp)
	p.imports = slices.NewSortedCompare([]string{}, strings.Compare)
	p.isInit = true
}
// Defs returns the list of definitions.
//
// The returned pointer refers to the package's own list, so additions through
// it are visible to Format and WriteDebug.
func (p *Pkg) Defs() *slices.Sorted[*protoDef] {
	p.init()

	defs := &p.protoDefs

	return defs
}
// Imports returns the list of imports.
//
// The returned pointer refers to the package's own list, so additions through
// it are visible to Format and WriteDebug.
func (p *Pkg) Imports() *slices.Sorted[string] {
	p.init()

	imports := &p.imports

	return imports
}
// WriteDebug is like Format, but writes additional debug info: the Go package
// path after the package statement, and each definition's own debug output.
func (p *Pkg) WriteDebug(w io.Writer) {
	fmt.Fprint(w, "syntax = \"proto3\";\n\n")
	fmt.Fprintf(w, "package talos.resource.definitions.%s; // %s\n\n", p.Name, p.GoPkg)
	fmt.Fprintf(w, "option go_package = \"github.com/talos-systems/talos/pkg/machinery/api/resource/definitions/%s\";\n\n", p.Name) // TODO: insert proper path

	if count := p.imports.Len(); count > 0 {
		for idx := 0; idx < count; idx++ {
			importPath := p.imports.Get(idx)

			// Paths without a dot get the talos.resource.definitions. prefix.
			if !strings.ContainsRune(importPath, '.') {
				importPath = "talos.resource.definitions." + importPath
			}

			fmt.Fprintf(w, "import \"%s\";\n", importPath)
		}

		// blank line separating the imports from the messages
		fmt.Fprintln(w)
	}

	for idx, count := 0, p.protoDefs.Len(); idx < count; idx++ {
		p.protoDefs.Get(idx).WriteDebug(w)
		fmt.Fprintln(w)
	}
}
// Format formats the protobuf data.
//
// It writes, in order: the proto3 syntax line, the package statement, the
// go_package option, one import line per import (if any), and every message
// definition in name order, each followed by a blank line.
func (p *Pkg) Format(w io.Writer) {
	fmt.Fprint(w, "syntax = \"proto3\";\n\n")
	fmt.Fprintf(w, "package talos.resource.definitions.%s;\n\n", p.Name)
	fmt.Fprintf(w, "option go_package = \"github.com/talos-systems/talos/pkg/machinery/api/resource/definitions/%s\";\n\n", p.Name) // TODO: insert proper path

	if count := p.imports.Len(); count > 0 {
		for idx := 0; idx < count; idx++ {
			importPath := p.imports.Get(idx)

			// Paths without a dot get the talos.resource.definitions. prefix.
			if !strings.ContainsRune(importPath, '.') {
				importPath = "talos.resource.definitions." + importPath
			}

			fmt.Fprintf(w, "import \"%s\";\n", importPath)
		}

		// blank line separating the imports from the messages
		fmt.Fprintln(w)
	}

	for idx, count := 0, p.protoDefs.Len(); idx < count; idx++ {
		p.protoDefs.Get(idx).Format(w)
		fmt.Fprintln(w)
	}
}
// protoDef is a single protobuf message generated from one Go type.
//
// Its field list is created lazily by init, which Fields calls.
type protoDef struct {
	name, goPkg string   // message name and Go package path of the source type
	comments    []string // written verbatim, one per line, above the message

	// isInit records whether fields has been set up by init.
	isInit bool
	fields slices.Sorted[protoField]
}
// protoDefCmp orders message definitions lexicographically by name.
func protoDefCmp(left, right *protoDef) int {
	l, r := left.name, right.name

	return strings.Compare(l, r)
}
// init creates the field list, ordered by field number via protoFieldCmp, on
// first use; later calls do nothing.
func (p *protoDef) init() {
	if p.isInit {
		return
	}

	p.fields = slices.NewSortedCompare([]protoField{}, protoFieldCmp)
	p.isInit = true
}
// Fields returns the message's own field list (sorted by field number),
// initializing it on first use.
func (p *protoDef) Fields() *slices.Sorted[protoField] {
	p.init()

	fields := &p.fields

	return fields
}
// WriteDebug is like Format, but the message line carries a trailing comment
// with the source Go type (package path and name), and each field is written
// with its own debug output.
func (p *protoDef) WriteDebug(w io.Writer) {
	for _, line := range p.comments {
		fmt.Fprintf(w, "%s\n", line)
	}

	fmt.Fprintf(w, "message %s { //%s.%s\n", p.name, p.goPkg, p.name)

	for idx, count := 0, p.fields.Len(); idx < count; idx++ {
		fmt.Fprintf(w, " ")
		p.fields.Get(idx).WriteDebug(w)
	}

	fmt.Fprintln(w, "}")
}
// Format writes the message: its comment lines, the message header, one
// indented line per field in field-number order, and the closing brace.
func (p *protoDef) Format(w io.Writer) {
	for _, line := range p.comments {
		fmt.Fprintf(w, "%s\n", line)
	}

	fmt.Fprintf(w, "message %s {\n", p.name)

	for idx, count := 0, p.fields.Len(); idx < count; idx++ {
		fmt.Fprintf(w, " ")
		p.fields.Get(idx).Format(w)
	}

	fmt.Fprintln(w, "}")
}
// protoField is a single field of a protobuf message.
type protoField struct {
	name, typ string // Go field name (snake_cased on output) and protobuf type
	num       int    // field number; protoFieldCmp panics when it is zero
	goType    string // Go type, emitted only by WriteDebug
}
// protoFieldCmp orders fields by ascending field number, returning -1, 0 or 1.
// It panics if either field has no number assigned (num == 0), because such a
// field cannot be placed in the generated message.
func protoFieldCmp(left, right protoField) int {
	switch {
	case left.num == 0:
		panic(fmt.Errorf("left field '%s' has no number", left.name))
	case right.num == 0:
		panic(fmt.Errorf("right field '%s' has no number", right.name))
	case left.num < right.num:
		return -1
	case left.num > right.num:
		return 1
	}

	return 0
}
// WriteDebug writes the field like Format, followed by a comment with its Go type.
func (pf protoField) WriteDebug(w io.Writer) {
	snakeName := ToSnakeCase(pf.name)

	fmt.Fprintf(w, "%s %s = %d; // %s \n", pf.typ, snakeName, pf.num, pf.goType)
}
// Format writes the field declaration "<type> <snake_name> = <num>;" on one line.
func (pf protoField) Format(w io.Writer) {
	snakeName := ToSnakeCase(pf.name)

	fmt.Fprintf(w, "%s %s = %d;\n", pf.typ, snakeName, pf.num)
}
// PrepareProtoData prepares the data for the protobuf generation.
//
// Every Go type in pkgsTypes becomes a message in the proto package derived
// from its Go package. Each field is translated according to the kind of its
// type (complex, basic, slice or map), and any import the translated type
// needs is recorded on the package. Fields whose type matches none of these
// kinds are skipped. It panics on types it cannot translate.
//
//nolint:gocyclo,cyclop
func PrepareProtoData(pkgsTypes slices.Sorted[*types.Type]) slices.Sorted[*Pkg] {
	pkgs := slices.NewSortedCompare([]*Pkg{}, protoPkgsCmp)

	for i := 0; i < pkgsTypes.Len(); i++ {
		srcType := pkgsTypes.Get(i)

		pkg := sliceutil.GetOrAdd(&pkgs, &Pkg{
			Name:  srcType.PkgName(),
			GoPkg: srcType.Pkg,
		})

		msg := sliceutil.GetOrAdd(pkg.Defs(), &protoDef{
			name:     srcType.Name,
			goPkg:    srcType.Pkg,
			comments: srcType.Comments,
		})

		for j := 0; j < srcType.Fields().Len(); j++ {
			field := srcType.Fields().Get(j)
			typeData := types.TypeInfo(field.TypeData.Type())

			// importPath stays empty when the proto type needs no import.
			var importPath, typeName string

			if complexTyp, ok := types.MatchTypeData[types.Complex](typeData); ok {
				importPath, typeName = mustFormatTypeName(complexTyp.Pkg, complexTyp.Name, srcType.Pkg)
			} else if basicTyp, ok := types.MatchTypeData[types.Basic](typeData); ok {
				importPath, typeName = mustFormatBasicTypeName(basicTyp.Pkg, basicTyp.Name)
			} else if sliceTyp, ok := types.MatchTypeData[types.Slice](typeData); ok {
				switch {
				case sliceTyp.Pkg == "" && sliceTyp.Name == "byte" && sliceTyp.Is2DSlice: //nolint:goconst
					typeName = "repeated bytes"
				case sliceTyp.Pkg == "" && sliceTyp.Name == "byte":
					typeName = "bytes"
				case sliceTyp.Pkg == "":
					typeName = fmt.Sprintf("repeated %s", getProtoBasicName(sliceTyp.Name))
				default:
					var elemName string

					importPath, elemName = mustFormatTypeName(sliceTyp.Pkg, sliceTyp.Name, srcType.Pkg)
					typeName = fmt.Sprintf("repeated %s", elemName)
				}
			} else if mapTyp, ok := types.MatchTypeData[types.Map](typeData); ok {
				// key cannot be anything but a basic type: a key type that would
				// need an import is rejected
				keyImport, keyName := mustFormatBasicTypeName(mapTyp.KeyTypePkg, mapTyp.KeyTypeName)
				if keyImport != "" {
					panic(fmt.Errorf("map key type '%s.%s' is not basic type", mapTyp.KeyTypePkg, mapTyp.KeyTypeName))
				}

				switch {
				case mapTyp.ElemTypeName == "interface{}": // handle map[key]interface{}
					importPath = "google/protobuf/struct.proto"
					typeName = "google.protobuf.Struct"
				case mapTyp.ElemTypePkg == "":
					typeName = fmt.Sprintf("map<%s, %s>", keyName, getProtoBasicName(mapTyp.ElemTypeName))
				case mapTyp.ElemTypePkg == srcType.Pkg:
					var elemName string

					importPath, elemName = mustFormatTypeName(mapTyp.ElemTypePkg, mapTyp.ElemTypeName, srcType.Pkg)
					typeName = fmt.Sprintf("map<%s, %s>", keyName, elemName)
				default:
					panic(fmt.Errorf("map value type '%s.%s' is not known type", mapTyp.ElemTypePkg, mapTyp.ElemTypeName))
				}
			} else {
				// unsupported field kind: nothing is emitted for it
				continue
			}

			if importPath != "" {
				sliceutil.AddIfNotFound(pkg.Imports(), importPath)
			}

			sliceutil.AddIfNotFound(msg.Fields(), protoField{
				name:   field.Name,
				typ:    typeName,
				num:    field.Num,
				goType: field.TypeData.Type().String(),
			})
		}
	}

	return pkgs
}
// mustFormatTypeName is like formatTypeName, but panics when the type is not
// supported.
//
// It returns the proto import path (empty when no import is needed) and the
// proto type name for the Go type fieldTypePkg.fieldType referenced from a
// type declared in declPkg.
func mustFormatTypeName(fieldTypePkg string, fieldType string, declPkg string) (string, string) {
	importPath, name := formatTypeName(fieldTypePkg, fieldType, declPkg)
	if name == "" {
		// Report as "pkg.Type", matching mustFormatBasicTypeName.
		panic(fmt.Errorf("unknown type %s.%s", fieldTypePkg, fieldType))
	}

	return importPath, name
}

// formatTypeName maps a Go type to its protobuf counterpart.
//
// Types declared in declPkg itself are referenced by their bare name and need
// no import. Otherwise only a fixed set of external types is supported: for
// those it returns the .proto file to import and the fully qualified proto
// type name. For any other type it returns two empty strings.
func formatTypeName(fieldTypePkg string, fieldType string, declPkg string) (string, string) {
	if fieldTypePkg == declPkg {
		return "", fieldType
	}

	type typeData struct {
		pkg  string
		name string
	}

	td := typeData{
		name: fieldType,
		pkg:  fieldTypePkg,
	}

	const commonProto = "common/common.proto"

	switch td {
	case typeData{"time", "Time"}:
		return "google/protobuf/timestamp.proto", "google.protobuf.Timestamp"
	case typeData{"net/url", "URL"}:
		return commonProto, "common.URL"
	case typeData{"net/netip", "Prefix"}:
		return commonProto, "common.NetIPPrefix"
	case typeData{"net/netip", "AddrPort"}:
		return commonProto, "common.NetIPPort"
	case typeData{"net/netip", "Addr"}:
		return commonProto, "common.NetIP"
	case typeData{"github.com/opencontainers/runtime-spec/specs-go", "Mount"}:
		return "resource/definitions/proto/proto.proto", "talos.resource.definitions.proto.Mount"
	case typeData{"github.com/siderolabs/crypto/x509", "PEMEncodedCertificateAndKey"}:
		return commonProto, "common.PEMEncodedCertificateAndKey"
	case typeData{"github.com/siderolabs/crypto/x509", "PEMEncodedKey"}:
		return commonProto, "common.PEMEncodedKey"
	default:
		return "", ""
	}
}
// mustFormatBasicTypeName maps a basic Go type to protobuf, returning the
// proto import path (empty if none is needed) and the proto type name.
//
// Builtin types (empty package) are resolved by getProtoBasicName. Named types
// from other packages must be known to formatBasicTypeName; otherwise it panics.
func mustFormatBasicTypeName(fieldTypePkg string, fieldType string) (string, string) {
	if fieldTypePkg != "" {
		importPath, fullName := formatBasicTypeName(fieldTypePkg, fieldType)
		if fullName == "" {
			panic(fmt.Errorf("unknown type %s.%s", fieldTypePkg, fieldType))
		}

		return importPath, fullName
	}

	return "", getProtoBasicName(fieldType)
}
// IsSupportedExternalType checks if external type is supported.
//
// A type is supported when formatBasicTypeName maps it, or else when
// formatTypeName (called with an empty declaring package) maps it.
func IsSupportedExternalType(typ types.ExternalType) bool {
	_, basicName := formatBasicTypeName(typ.Pkg, typ.Name)
	if basicName != "" {
		return true
	}

	_, typeName := formatTypeName(typ.Pkg, typ.Name, "")

	return typeName != ""
}
// formatBasicTypeName maps a named Go type with a basic underlying kind, declared
// in package typPkg, to protobuf.
//
// It returns the .proto file to import (empty when none is needed) and the
// proto type name. Both are empty for types it does not know.
//
//nolint:gocyclo,cyclop
func formatBasicTypeName(typPkg string, typ string) (importPath, fullName string) {
	const (
		enumsProto    = "resource/definitions/enums/enums.proto"
		nethelpersPkg = "github.com/talos-systems/talos/pkg/machinery/nethelpers"
	)

	switch typPkg {
	case "time":
		if typ == "Duration" {
			return "google/protobuf/duration.proto", "google.protobuf.Duration"
		}
	case "io/fs":
		if typ == "FileMode" {
			return "", "uint32" //nolint:goconst
		}
	case "github.com/talos-systems/talos/pkg/machinery/config/types/v1alpha1/machine":
		if typ == "Type" {
			return enumsProto, "talos.resource.definitions.enums.MachineType"
		}
	case "github.com/talos-systems/talos/pkg/machinery/resources/kubespan":
		if typ == "PeerState" {
			return enumsProto, "talos.resource.definitions.enums.KubespanPeerState"
		}
	case "github.com/talos-systems/talos/pkg/machinery/resources/network":
		switch typ {
		case "ConfigLayer":
			return enumsProto, "talos.resource.definitions.enums.NetworkConfigLayer"
		case "Operator":
			return enumsProto, "talos.resource.definitions.enums.NetworkOperator"
		}
	case "github.com/talos-systems/talos/pkg/machinery/resources/runtime":
		if typ == "MachineStage" {
			return enumsProto, "talos.resource.definitions.enums.RuntimeMachineStage"
		}
	case nethelpersPkg:
		switch typ {
		case "AddressFlags", "LinkFlags", "RouteFlags":
			// the *Flags types are mapped to a plain uint32 rather than an enum
			return "", "uint32"
		case "Family":
			return enumsProto, "talos.resource.definitions.enums.NethelpersFamily"
		case "Scope":
			return enumsProto, "talos.resource.definitions.enums.NethelpersScope"
		case "ADSelect":
			return enumsProto, "talos.resource.definitions.enums.NethelpersADSelect"
		case "ARPAllTargets":
			return enumsProto, "talos.resource.definitions.enums.NethelpersARPAllTargets"
		case "ARPValidate":
			return enumsProto, "talos.resource.definitions.enums.NethelpersARPValidate"
		case "FailOverMAC":
			return enumsProto, "talos.resource.definitions.enums.NethelpersFailOverMAC"
		case "BondXmitHashPolicy":
			return enumsProto, "talos.resource.definitions.enums.NethelpersBondXmitHashPolicy"
		case "LACPRate":
			return enumsProto, "talos.resource.definitions.enums.NethelpersLACPRate"
		case "BondMode":
			return enumsProto, "talos.resource.definitions.enums.NethelpersBondMode"
		case "PrimaryReselect":
			return enumsProto, "talos.resource.definitions.enums.NethelpersPrimaryReselect"
		case "LinkType":
			return enumsProto, "talos.resource.definitions.enums.NethelpersLinkType"
		case "Duplex":
			return enumsProto, "talos.resource.definitions.enums.NethelpersDuplex"
		case "OperationalState":
			return enumsProto, "talos.resource.definitions.enums.NethelpersOperationalState"
		case "Port":
			return enumsProto, "talos.resource.definitions.enums.NethelpersPort"
		case "RouteProtocol":
			return enumsProto, "talos.resource.definitions.enums.NethelpersRouteProtocol"
		case "RoutingTable":
			return enumsProto, "talos.resource.definitions.enums.NethelpersRoutingTable"
		case "RouteType":
			return enumsProto, "talos.resource.definitions.enums.NethelpersRouteType"
		case "VLANProtocol":
			return enumsProto, "talos.resource.definitions.enums.NethelpersVLANProtocol"
		}
	}

	return "", ""
}
// getProtoBasicName maps a Go builtin scalar type name to the protobuf scalar
// type used for it. It panics with a string value on any other type name.
//
// NOTE(review): the signed int8/int16 map to the unsigned fixed32, the same as
// byte/uint8/uint16; confirm this is intended for negative values.
//
//nolint:gocyclo
func getProtoBasicName(typ string) string {
	switch typ {
	case "bool", "int32", "uint32", "string":
		// spelled the same in Go and protobuf
		return typ
	case "int8", "int16", "byte", "uint8", "uint16":
		// proto3 has no 8/16-bit scalars
		return "fixed32"
	case "int64", "int":
		return "int64"
	case "uint64", "uint":
		return "uint64"
	case "float32":
		return "float"
	case "float64":
		return "double"
	}

	panic(fmt.Sprintf("unknown type %s", typ))
}
// Regular expressions used by ToSnakeCase to split CamelCase identifiers.
var (
	// matchFirstCap matches any character followed by an upper-case letter
	// and one or more lower-case letters.
	matchFirstCap = regexp.MustCompile(`(.)([A-Z][a-z]+)`)
	// matchAllCap matches a lower-case letter or digit directly followed by
	// an upper-case letter.
	matchAllCap = regexp.MustCompile(`([a-z0-9])([A-Z])`)
)

// ToSnakeCase converts a string to snake case, e.g. "LinkIndex" becomes
// "link_index" and "HTTPServer" becomes "http_server".
func ToSnakeCase(str string) string {
	const (
		splitIPs  = "_i_ps" // what the regexps turn a trailing "IPs" into
		joinedIPs = "_ips"
	)

	out := strings.ToLower(
		matchAllCap.ReplaceAllString(
			matchFirstCap.ReplaceAllString(str, "${1}_${2}"),
			"${1}_${2}",
		),
	)

	// special case for "SomethingsIps"
	if strings.HasSuffix(out, splitIPs) {
		out = out[:len(out)-len(splitIPs)] + joinedIPs
	}

	return out
}