lint: enable protogetter linter (#7336)

Enable protogetter in golangci config and update all protobuf field
access to use getter methods instead of direct field access.
Getter methods provide safer nil pointer handling and return
appropriate default values, following protobuf best practices.

Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
This commit is contained in:
Ville Vesilehto 2025-05-31 01:29:32 +03:00 committed by GitHub
parent 7ecb5011b2
commit 53e9681a39
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 46 additions and 31 deletions

View File

@ -8,6 +8,7 @@ linters:
- ineffassign - ineffassign
- intrange - intrange
- nolintlint - nolintlint
- protogetter
- staticcheck - staticcheck
- unconvert - unconvert
- unused - unused

View File

@ -123,7 +123,7 @@ func (s *ServergRPC) Stop() (err error) {
// back to the client as a protobuf. // back to the client as a protobuf.
func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket, error) { func (s *ServergRPC) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket, error) {
msg := new(dns.Msg) msg := new(dns.Msg)
err := msg.Unpack(in.Msg) err := msg.Unpack(in.GetMsg())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -209,13 +209,13 @@ func TestServergRPC_Query(t *testing.T) {
t.Errorf("Query() failed: %v", err) t.Errorf("Query() failed: %v", err)
} }
if len(response.Msg) == 0 { if len(response.GetMsg()) == 0 {
t.Error("Query() returned empty message") t.Error("Query() returned empty message")
} }
// Verify the response can be unpacked // Verify the response can be unpacked
respMsg := new(dns.Msg) respMsg := new(dns.Msg)
err = respMsg.Unpack(response.Msg) err = respMsg.Unpack(response.GetMsg())
if err != nil { if err != nil {
t.Errorf("Failed to unpack response message: %v", err) t.Errorf("Failed to unpack response message: %v", err)
} }

View File

@ -47,23 +47,37 @@ func (w *writer) Dnstap(e *tap.Dnstap) {
w.t.Error("Message not expected") w.t.Error("Message not expected")
} }
ex := w.queue[0].Message ex := w.queue[0].GetMessage()
got := e.Message got := e.GetMessage()
if string(ex.QueryAddress) != string(got.QueryAddress) { eaddr := string(ex.GetQueryAddress())
w.t.Errorf("Expected source address %s, got %s", ex.QueryAddress, got.QueryAddress) gaddr := string(got.GetQueryAddress())
if eaddr != gaddr {
w.t.Errorf("Expected source address %s, got %s", eaddr, gaddr)
} }
if string(ex.ResponseAddress) != string(got.ResponseAddress) {
w.t.Errorf("Expected response address %s, got %s", ex.ResponseAddress, got.ResponseAddress) eraddr := string(ex.GetResponseAddress())
graddr := string(got.GetResponseAddress())
if eraddr != graddr {
w.t.Errorf("Expected response address %s, got %s", eraddr, graddr)
} }
if *ex.QueryPort != *got.QueryPort {
w.t.Errorf("Expected port %d, got %d", *ex.QueryPort, *got.QueryPort) ep := ex.GetQueryPort()
gp := got.GetQueryPort()
if ep != gp {
w.t.Errorf("Expected port %d, got %d", ep, gp)
} }
if *ex.SocketFamily != *got.SocketFamily {
w.t.Errorf("Expected socket family %d, got %d", *ex.SocketFamily, *got.SocketFamily) ef := ex.GetSocketFamily()
sf := got.GetSocketFamily()
if ef != sf {
w.t.Errorf("Expected socket family %d, got %d", ef, sf)
} }
if string(w.queue[0].Extra) != string(e.Extra) {
w.t.Errorf("Expected extra %s, got %s", w.queue[0].Extra, e.Extra) eext := string(w.queue[0].GetExtra())
gext := string(e.GetExtra())
if eext != gext {
w.t.Errorf("Expected extra %s, got %s", eext, gext)
} }
w.queue = w.queue[1:] w.queue = w.queue[1:]
} }
@ -80,23 +94,23 @@ func TestDnstap(t *testing.T) {
tapq := &tap.Dnstap{ tapq := &tap.Dnstap{
Message: testMessage(), Message: testMessage(),
} }
msg.SetType(tapq.Message, tap.Message_CLIENT_QUERY) msg.SetType(tapq.GetMessage(), tap.Message_CLIENT_QUERY)
tapr := &tap.Dnstap{ tapr := &tap.Dnstap{
Message: testMessage(), Message: testMessage(),
} }
msg.SetType(tapr.Message, tap.Message_CLIENT_RESPONSE) msg.SetType(tapr.GetMessage(), tap.Message_CLIENT_RESPONSE)
testCase(t, tapq, tapr, q, r, "") testCase(t, tapq, tapr, q, r, "")
tapq_with_extra := &tap.Dnstap{ tapq_with_extra := &tap.Dnstap{
Message: testMessage(), // leave type unset for deepEqual Message: testMessage(), // leave type unset for deepEqual
Extra: []byte("extra_field_MetadataValue_A_example.org._IN_udp_29_10.240.0.1_40212_127.0.0.1"), Extra: []byte("extra_field_MetadataValue_A_example.org._IN_udp_29_10.240.0.1_40212_127.0.0.1"),
} }
msg.SetType(tapq_with_extra.Message, tap.Message_CLIENT_QUERY) msg.SetType(tapq_with_extra.GetMessage(), tap.Message_CLIENT_QUERY)
tapr_with_extra := &tap.Dnstap{ tapr_with_extra := &tap.Dnstap{
Message: testMessage(), Message: testMessage(),
Extra: []byte("extra_field_MetadataValue_A_example.org._IN_udp_29_10.240.0.1_40212_127.0.0.1"), Extra: []byte("extra_field_MetadataValue_A_example.org._IN_udp_29_10.240.0.1_40212_127.0.0.1"),
} }
msg.SetType(tapr_with_extra.Message, tap.Message_CLIENT_RESPONSE) msg.SetType(tapr_with_extra.GetMessage(), tap.Message_CLIENT_RESPONSE)
extraFormat := "extra_field_{/metadata/test}_{type}_{name}_{class}_{proto}_{size}_{remote}_{port}_{local}" extraFormat := "extra_field_{/metadata/test}_{type}_{name}_{class}_{proto}_{size}_{remote}_{port}_{local}"
testCase(t, tapq_with_extra, tapr_with_extra, q, r, extraFormat) testCase(t, tapq_with_extra, tapr_with_extra, q, r, extraFormat)
} }
@ -120,7 +134,7 @@ func TestTapMessage(t *testing.T) {
// extra field would not be replaced, since TapMessage won't pass context // extra field would not be replaced, since TapMessage won't pass context
Extra: []byte(extraFormat), Extra: []byte(extraFormat),
} }
msg.SetType(tapq.Message, tap.Message_CLIENT_QUERY) msg.SetType(tapq.GetMessage(), tap.Message_CLIENT_QUERY)
w := writer{t: t} w := writer{t: t}
w.queue = append(w.queue, tapq) w.queue = append(w.queue, tapq)
@ -132,5 +146,5 @@ func TestTapMessage(t *testing.T) {
io: &w, io: &w,
ExtraFormat: extraFormat, ExtraFormat: extraFormat,
} }
h.TapMessage(tapq.Message) h.TapMessage(tapq.GetMessage())
} }

View File

@ -65,7 +65,7 @@ func (p *Proxy) query(ctx context.Context, req *dns.Msg) (*dns.Msg, error) {
return nil, err return nil, err
} }
ret := new(dns.Msg) ret := new(dns.Msg)
if err := ret.Unpack(reply.Msg); err != nil { if err := ret.Unpack(reply.GetMsg()); err != nil {
return nil, err return nil, err
} }

View File

@ -150,9 +150,9 @@ func newMetricFamily(dtoMF *dto.MetricFamily) *MetricFamily {
Name: dtoMF.GetName(), Name: dtoMF.GetName(),
Help: dtoMF.GetHelp(), Help: dtoMF.GetHelp(),
Type: dtoMF.GetType().String(), Type: dtoMF.GetType().String(),
Metrics: make([]interface{}, len(dtoMF.Metric)), Metrics: make([]interface{}, len(dtoMF.GetMetric())),
} }
for i, m := range dtoMF.Metric { for i, m := range dtoMF.GetMetric() {
if dtoMF.GetType() == dto.MetricType_SUMMARY { if dtoMF.GetType() == dto.MetricType_SUMMARY {
mf.Metrics[i] = summary{ mf.Metrics[i] = summary{
Labels: makeLabels(m), Labels: makeLabels(m),
@ -178,13 +178,13 @@ func newMetricFamily(dtoMF *dto.MetricFamily) *MetricFamily {
} }
func value(m *dto.Metric) float64 { func value(m *dto.Metric) float64 {
if m.Gauge != nil { if m.GetGauge() != nil {
return m.GetGauge().GetValue() return m.GetGauge().GetValue()
} }
if m.Counter != nil { if m.GetCounter() != nil {
return m.GetCounter().GetValue() return m.GetCounter().GetValue()
} }
if m.Untyped != nil { if m.GetUntyped() != nil {
return m.GetUntyped().GetValue() return m.GetUntyped().GetValue()
} }
return 0. return 0.
@ -192,7 +192,7 @@ func value(m *dto.Metric) float64 {
func makeLabels(m *dto.Metric) map[string]string { func makeLabels(m *dto.Metric) map[string]string {
result := map[string]string{} result := map[string]string{}
for _, lp := range m.Label { for _, lp := range m.GetLabel() {
result[lp.GetName()] = lp.GetValue() result[lp.GetName()] = lp.GetValue()
} }
return result return result
@ -200,7 +200,7 @@ func makeLabels(m *dto.Metric) map[string]string {
func makeQuantiles(m *dto.Metric) map[string]string { func makeQuantiles(m *dto.Metric) map[string]string {
result := map[string]string{} result := map[string]string{}
for _, q := range m.GetSummary().Quantile { for _, q := range m.GetSummary().GetQuantile() {
result[fmt.Sprint(q.GetQuantile())] = fmt.Sprint(q.GetValue()) result[fmt.Sprint(q.GetQuantile())] = fmt.Sprint(q.GetValue())
} }
return result return result
@ -208,7 +208,7 @@ func makeQuantiles(m *dto.Metric) map[string]string {
func makeBuckets(m *dto.Metric) map[string]string { func makeBuckets(m *dto.Metric) map[string]string {
result := map[string]string{} result := map[string]string{}
for _, b := range m.GetHistogram().Bucket { for _, b := range m.GetHistogram().GetBucket() {
result[fmt.Sprint(b.GetUpperBound())] = fmt.Sprint(b.GetCumulativeCount()) result[fmt.Sprint(b.GetUpperBound())] = fmt.Sprint(b.GetCumulativeCount())
} }
return result return result

View File

@ -40,7 +40,7 @@ func TestGrpc(t *testing.T) {
} }
d := new(dns.Msg) d := new(dns.Msg)
err = d.Unpack(reply.Msg) err = d.Unpack(reply.GetMsg())
if err != nil { if err != nil {
t.Errorf("Expected no error but got: %s", err) t.Errorf("Expected no error but got: %s", err)
} }