Add support for fallthrough to the grpc plugin (#7359)

Fixes: https://github.com/coredns/coredns/issues/7358

Signed-off-by: Blake Barnett <bbarnett@groq.com>
This commit is contained in:
blakebarnett 2025-06-06 04:58:17 -07:00 committed by GitHub
parent 0eb5542035
commit 6cba588951
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 118 additions and 1 deletions

View File

@ -33,6 +33,7 @@ grpc FROM TO... {
tls CERT KEY CA
tls_servername NAME
policy random|round_robin|sequential
fallthrough [ZONES...]
}
~~~
@ -54,6 +55,12 @@ grpc FROM TO... {
but they have to use the same `tls_servername`. E.g. mixing 9.9.9.9 (QuadDNS) with 1.1.1.1
(Cloudflare) will not work.
* `policy` specifies the policy to use for selecting upstream servers. The default is `random`.
* `fallthrough` **[ZONES...]** If a query results in NXDOMAIN from the gRPC backend, pass the request
to the next plugin instead of returning the NXDOMAIN response. This is useful when the gRPC backend
is authoritative for a zone but should not return authoritative NXDOMAIN responses for queries that
don't actually belong to that zone (e.g., search path queries). If **[ZONES...]** is omitted, then
fallthrough happens for all zones. If specific zones are listed, then only queries for those zones
will be subject to fallthrough.
Also note the TLS config is "global" for the whole grpc proxy if you need a different
`tls-name` for different upstreams you're out of luck.
@ -137,6 +144,17 @@ Forward requests to a local upstream listening on a Unix domain socket.
}
~~~
Proxy requests for `example.org.` to a gRPC backend, but fallthrough to the next plugin for NXDOMAIN responses to handle search path queries correctly.
~~~ corefile
example.org {
grpc . 127.0.0.1:9005 {
fallthrough
}
forward . 8.8.8.8
}
~~~
## Bugs
The TLS config is global for the whole grpc proxy if you need a different `tls_servername` for

View File

@ -8,6 +8,7 @@ import (
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/debug"
"github.com/coredns/coredns/plugin/pkg/fall"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
@ -26,6 +27,7 @@ type GRPC struct {
tlsConfig *tls.Config
tlsServerName string
Fall fall.F
Next plugin.Handler
}
@ -33,7 +35,11 @@ type GRPC struct {
func (g *GRPC) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
if !g.match(state) {
return plugin.NextOrFailure(g.Name(), g.Next, ctx, w, r)
if g.Next != nil {
return plugin.NextOrFailure(g.Name(), g.Next, ctx, w, r)
}
// No next plugin, return SERVFAIL
return dns.RcodeServerFailure, nil
}
var (
@ -84,16 +90,33 @@ func (g *GRPC) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (
return 0, nil
}
// Check if we should fallthrough on NXDOMAIN responses
if ret.Rcode == dns.RcodeNameError && g.Fall.Through(state.Name()) {
if g.Next != nil {
return plugin.NextOrFailure(g.Name(), g.Next, ctx, w, r)
}
// No next plugin to fallthrough to, return the NXDOMAIN response
}
w.WriteMsg(ret)
return 0, nil
}
// SERVFAIL if all healthy proxys returned errors.
if err != nil {
// If fallthrough is enabled, try the next plugin instead of returning SERVFAIL
if g.Fall.Through(state.Name()) && g.Next != nil {
return plugin.NextOrFailure(g.Name(), g.Next, ctx, w, r)
}
// just return the last error received
return dns.RcodeServerFailure, err
}
// If fallthrough is enabled, try the next plugin instead of returning SERVFAIL
if g.Fall.Through(state.Name()) && g.Next != nil {
return plugin.NextOrFailure(g.Name(), g.Next, ctx, w, r)
}
return dns.RcodeServerFailure, ErrNoHealthy
}

View File

@ -3,10 +3,12 @@ package grpc
import (
"context"
"errors"
"strings"
"testing"
"github.com/coredns/coredns/pb"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/fall"
"github.com/coredns/coredns/plugin/test"
"github.com/miekg/dns"
@ -73,3 +75,30 @@ func TestGRPC(t *testing.T) {
})
}
}
// Test that fallthrough works correctly when there's no next plugin
func TestGRPCFallthroughNoNext(t *testing.T) {
g := newGRPC() // Use the constructor to properly initialize
g.Fall = fall.Root // Enable fallthrough for all zones
g.Next = nil // No next plugin
g.from = "."
// Create a test request
r := new(dns.Msg)
r.SetQuestion("test.example.org.", dns.TypeA)
w := &test.ResponseWriter{}
// Should return SERVFAIL since no backends are configured and no next plugin
rcode, err := g.ServeDNS(context.Background(), w, r)
// Should not return the "no next plugin found" error
if err != nil && strings.Contains(err.Error(), "no next plugin found") {
t.Errorf("Expected no 'no next plugin found' error, got: %v", err)
}
// Should return SERVFAIL
if rcode != dns.RcodeServerFailure {
t.Errorf("Expected SERVFAIL when no backends and no next plugin, got: %d", rcode)
}
}

View File

@ -141,6 +141,8 @@ func parseBlock(c *caddy.Controller, g *GRPC) error {
default:
return c.Errf("unknown policy '%s'", x)
}
case "fallthrough":
g.Fall.SetZonesFromArgs(c.RemainingArgs())
default:
if c.Val() != "}" {
return c.Errf("unknown property '%s'", c.Val())

View File

@ -7,6 +7,7 @@ import (
"testing"
"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin/pkg/fall"
)
func TestSetup(t *testing.T) {
@ -152,3 +153,47 @@ nameserver 10.10.255.253`), 0666); err != nil {
}
}
}
func TestSetupFallthrough(t *testing.T) {
tests := []struct {
input string
shouldErr bool
expectedFallthrough fall.F
expectedErr string
}{
// positive cases
{`grpc . 127.0.0.1 {
fallthrough
}`, false, fall.Root, ""},
{`grpc . 127.0.0.1 {
fallthrough example.org
}`, false, fall.F{Zones: []string{"example.org."}}, ""},
{`grpc . 127.0.0.1 {
fallthrough example.org example.com
}`, false, fall.F{Zones: []string{"example.org.", "example.com."}}, ""},
{`grpc . 127.0.0.1`, false, fall.Zero, ""},
}
for i, test := range tests {
c := caddy.NewTestController("dns", test.input)
g, err := parseGRPC(c)
if test.shouldErr && err == nil {
t.Errorf("Test %d: expected error but found none for input %s", i, test.input)
}
if err != nil {
if !test.shouldErr {
t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
}
if !strings.Contains(err.Error(), test.expectedErr) {
t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
}
}
if !test.shouldErr && !g.Fall.Equal(test.expectedFallthrough) {
t.Errorf("Test %d: expected fallthrough %+v, got %+v", i, test.expectedFallthrough, g.Fall)
}
}
}