diff --git a/core/dnsserver/quic.go b/core/dnsserver/quic.go index 5c2890a72..dbe9552a8 100644 --- a/core/dnsserver/quic.go +++ b/core/dnsserver/quic.go @@ -2,6 +2,7 @@ package dnsserver import ( "encoding/binary" + "errors" "net" "github.com/miekg/dns" @@ -11,11 +12,14 @@ import ( type DoQWriter struct { localAddr net.Addr remoteAddr net.Addr - stream quic.Stream + stream *quic.Stream Msg *dns.Msg } func (w *DoQWriter) Write(b []byte) (int, error) { + if w.stream == nil { + return 0, errors.New("stream is nil") + } b = AddPrefix(b) return w.stream.Write(b) } @@ -40,6 +44,9 @@ func (w *DoQWriter) WriteMsg(m *dns.Msg) error { // mechanism that no further data will be sent on that stream. // See https://www.rfc-editor.org/rfc/rfc9250#section-4.2-7 func (w *DoQWriter) Close() error { + if w.stream == nil { + return errors.New("stream is nil") + } return w.stream.Close() } diff --git a/core/dnsserver/quic_test.go b/core/dnsserver/quic_test.go index 4a4f408cd..7e7301906 100644 --- a/core/dnsserver/quic_test.go +++ b/core/dnsserver/quic_test.go @@ -1,15 +1,8 @@ package dnsserver import ( - "bytes" - "context" - "errors" "net" "testing" - "time" - - "github.com/miekg/dns" - "github.com/quic-go/quic-go" ) func TestDoQWriterAddPrefix(t *testing.T) { @@ -55,210 +48,3 @@ func TestDoQWriter_ResponseWriterMethods(t *testing.T) { t.Errorf("RemoteAddr() = %v, want %v", addr, remoteAddr) } } - -// mockQuicStream is a mock implementation of quic.Stream for testing. -type mockQuicStream struct { - writer func(p []byte) (n int, err error) - closer func() error - closed bool - data []byte -} - -func (m *mockQuicStream) Write(p []byte) (n int, err error) { - m.data = append(m.data, p...) - if m.writer != nil { - return m.writer(p) - } - return len(p), nil -} - -func (m *mockQuicStream) Close() error { - m.closed = true - if m.closer != nil { - return m.closer() - } - return nil -} - -// Required by quic.Stream interface, but not used in these tests -func (m *mockQuicStream) Read(p []byte) (n int, err error) { return 0, nil } -func (m *mockQuicStream) CancelRead(code quic.StreamErrorCode) {} -func (m *mockQuicStream) CancelWrite(code quic.StreamErrorCode) {} -func (m *mockQuicStream) SetReadDeadline(t time.Time) error { return nil } -func (m *mockQuicStream) SetWriteDeadline(t time.Time) error { return nil } -func (m *mockQuicStream) SetDeadline(t time.Time) error { return nil } -func (m *mockQuicStream) StreamID() quic.StreamID { return 0 } -func (m *mockQuicStream) Context() context.Context { return nil } - -func TestDoQWriter_Write(t *testing.T) { - tests := []struct { - name string - input []byte - streamWriter func(p []byte) (n int, err error) - expectErr bool - expectedData []byte - expectedN int - }{ - { - name: "successful write", - input: []byte{0x1, 0x2, 0x3}, - streamWriter: func(p []byte) (n int, err error) { - return len(p), nil - }, - expectErr: false, - expectedData: []byte{0x0, 0x3, 0x1, 0x2, 0x3}, // 3-byte length prefix - expectedN: 5, - }, - { - name: "stream write error", - input: []byte{0x4, 0x5}, - streamWriter: func(p []byte) (n int, err error) { - return 0, errors.New("stream error") - }, - expectErr: true, - expectedData: []byte{0x0, 0x2, 0x4, 0x5}, // 2-byte length prefix - expectedN: 0, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockStream := &mockQuicStream{writer: tt.streamWriter} - writer := &DoQWriter{stream: mockStream} - - n, err := writer.Write(tt.input) - - if (err != nil) != tt.expectErr { - t.Errorf("Write() error = %v, expectErr %v", err, tt.expectErr) - return - } - if n != tt.expectedN { - t.Errorf("Write() n = %v, want %v", n, tt.expectedN) - } - - if !bytes.Equal(mockStream.data, tt.expectedData) { - t.Errorf("Write() data written to stream = %X, want %X", mockStream.data, tt.expectedData) - } - }) - } -} - -func TestDoQWriter_WriteMsg(t *testing.T) { - newMsg := func() *dns.Msg { - m := new(dns.Msg) - m.SetQuestion("example.com.", dns.TypeA) - return m - } - - tests := []struct { - name string - msg *dns.Msg - mockStream *mockQuicStream - expectErr bool - expectClosed bool - expectedData []byte // Expected data written to stream (packed msg with prefix) - packErr bool // Simulate error during msg.Pack() - }{ - { - name: "successful write and close", - msg: newMsg(), - mockStream: &mockQuicStream{}, - expectErr: false, - expectClosed: true, - }, - { - name: "msg.Pack() error", - msg: new(dns.Msg), - mockStream: &mockQuicStream{}, - expectErr: true, - packErr: true, // We'll make msg.Pack() fail by corrupting the msg or using a mock - expectClosed: false, // Close should not be called if Pack fails - }, - { - name: "stream write error", - msg: newMsg(), - mockStream: &mockQuicStream{ - writer: func(p []byte) (n int, err error) { - return 0, errors.New("stream write failed") - }, - }, - expectErr: true, - expectClosed: false, // Close should not be called if Write fails - }, - { - name: "stream close error", - msg: newMsg(), - mockStream: &mockQuicStream{ - closer: func() error { - return errors.New("stream close failed") - }, - }, - expectErr: true, - expectClosed: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.packErr { - // Intentionally make the message invalid to cause a pack error. - // Invalid Rcode to ensure Pack fails. - tt.msg.Rcode = 1337 - } - - writer := &DoQWriter{stream: tt.mockStream, Msg: tt.msg} - err := writer.WriteMsg(tt.msg) - - if (err != nil) != tt.expectErr { - t.Errorf("WriteMsg() error = %v, expectErr %v", err, tt.expectErr) - } - - if tt.mockStream.closed != tt.expectClosed { - t.Errorf("WriteMsg() stream closed = %v, want %v", tt.mockStream.closed, tt.expectClosed) - } - - if tt.packErr { - if len(tt.mockStream.data) != 0 { - t.Errorf("WriteMsg() data written to stream on pack error = %X, want empty", tt.mockStream.data) - } - } - }) - } -} - -func TestDoQWriter_Close(t *testing.T) { - tests := []struct { - name string - mockStream *mockQuicStream - expectErr bool - }{ - { - name: "successful close", - mockStream: &mockQuicStream{}, - expectErr: false, - }, - { - name: "stream close error", - mockStream: &mockQuicStream{ - closer: func() error { - return errors.New("stream close error") - }, - }, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - writer := &DoQWriter{stream: tt.mockStream} - err := writer.Close() - - if (err != nil) != tt.expectErr { - t.Errorf("Close() error = %v, expectErr %v", err, tt.expectErr) - } - if !tt.mockStream.closed { - t.Errorf("Close() stream not marked as closed") - } - }) - } -} diff --git a/core/dnsserver/server_quic.go b/core/dnsserver/server_quic.go index 1a3b2c456..2e42155a9 100644 --- a/core/dnsserver/server_quic.go +++ b/core/dnsserver/server_quic.go @@ -129,7 +129,10 @@ func (s *ServerQUIC) ServeQUIC() error { // serveQUICConnection handles a new QUIC connection. It waits for new streams // and passes them to serveQUICStream. -func (s *ServerQUIC) serveQUICConnection(conn quic.Connection) { +func (s *ServerQUIC) serveQUICConnection(conn *quic.Conn) { + if conn == nil { + return + } for { // In DoQ, one query consumes one stream. // The client MUST select the next available client-initiated bidirectional @@ -147,14 +150,21 @@ func (s *ServerQUIC) serveQUICConnection(conn quic.Connection) { // Use a bounded worker pool s.streamProcessPool <- struct{}{} // Acquire a worker slot, may block - go func(st quic.Stream, cn quic.Connection) { + go func(st *quic.Stream, cn *quic.Conn) { defer func() { <-s.streamProcessPool }() // Release worker slot s.serveQUICStream(st, cn) }(stream, conn) } } -func (s *ServerQUIC) serveQUICStream(stream quic.Stream, conn quic.Connection) { +func (s *ServerQUIC) serveQUICStream(stream *quic.Stream, conn *quic.Conn) { + if conn == nil { + return + } + if stream == nil { + s.closeQUICConn(conn, DoQCodeInternalError) + return + } buf, err := readDOQMessage(stream) // io.EOF does not really mean that there's any error, it is just @@ -249,7 +259,7 @@ func (s *ServerQUIC) Serve(l net.Listener) error { return nil } func (s *ServerQUIC) Listen() (net.Listener, error) { return nil, nil } // closeQUICConn quietly closes the QUIC connection. -func (s *ServerQUIC) closeQUICConn(conn quic.Connection, code quic.ApplicationErrorCode) { +func (s *ServerQUIC) closeQUICConn(conn *quic.Conn, code quic.ApplicationErrorCode) { if conn == nil { return } diff --git a/core/dnsserver/server_quic_test.go b/core/dnsserver/server_quic_test.go index 96674721c..8deb11c7c 100644 --- a/core/dnsserver/server_quic_test.go +++ b/core/dnsserver/server_quic_test.go @@ -2,14 +2,9 @@ package dnsserver import ( "bytes" - "context" "crypto/tls" - "encoding/binary" "errors" - "io" - "net" "testing" - "time" "github.com/miekg/dns" "github.com/quic-go/quic-go" @@ -373,74 +368,6 @@ func TestReadDOQMessage(t *testing.T) { } } -func TestDoQWriter(t *testing.T) { - mockStream := &mockQUICStream{} - localAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:53") - remoteAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:12345") - - writer := &DoQWriter{ - localAddr: localAddr, - remoteAddr: remoteAddr, - stream: mockStream, - } - - if writer.LocalAddr() != localAddr { - t.Errorf("LocalAddr() = %v, want %v", writer.LocalAddr(), localAddr) - } - - if writer.RemoteAddr() != remoteAddr { - t.Errorf("RemoteAddr() = %v, want %v", writer.RemoteAddr(), remoteAddr) - } - - testData := []byte("test message") - n, err := writer.Write(testData) - if err != nil { - t.Errorf("Write() failed: %v", err) - } - - expectedLen := len(testData) + 2 // +2 for length prefix - if n != expectedLen { - t.Errorf("Write() returned %d, want %d", n, expectedLen) - } - - // Verify the written data includes length prefix - written := mockStream.writtenData - if len(written) != expectedLen { - t.Errorf("Expected written data length %d, got %d", expectedLen, len(written)) - } - - // Check length prefix - expectedLength := uint16(len(testData)) - actualLength := binary.BigEndian.Uint16(written[:2]) - if actualLength != expectedLength { - t.Errorf("Expected length prefix %d, got %d", expectedLength, actualLength) - } - - // Check message content - if !bytes.Equal(written[2:], testData) { - t.Errorf("Expected message content %v, got %v", testData, written[2:]) - } - - // Test WriteMsg method - msg := new(dns.Msg) - msg.SetQuestion("example.com.", dns.TypeA) - msg.Id = 0 - - mockStream.reset() - err = writer.WriteMsg(msg) - if err != nil { - t.Errorf("WriteMsg() failed: %v", err) - } - - if !mockStream.closed { - t.Error("WriteMsg() should close the stream") - } - - if err := writer.TsigStatus(); err != nil { - t.Errorf("TsigStatus() returned error: %v", err) - } -} - func TestAddPrefix(t *testing.T) { tests := []struct { name string @@ -473,34 +400,3 @@ func TestAddPrefix(t *testing.T) { }) } } - -type mockQUICStream struct { - writtenData []byte - closed bool -} - -func (m *mockQUICStream) Write(data []byte) (int, error) { - m.writtenData = append(m.writtenData, data...) - return len(data), nil -} - -func (m *mockQUICStream) Read([]byte) (int, error) { return 0, io.EOF } - -func (m *mockQUICStream) Close() error { - m.closed = true - return nil -} - -func (m *mockQUICStream) reset() { - m.writtenData = nil - m.closed = false -} - -// Minimal implementation of other required methods -func (m *mockQUICStream) StreamID() quic.StreamID { return 0 } -func (m *mockQUICStream) SetReadDeadline(time.Time) error { return nil } -func (m *mockQUICStream) SetWriteDeadline(time.Time) error { return nil } -func (m *mockQUICStream) SetDeadline(time.Time) error { return nil } -func (m *mockQUICStream) Context() context.Context { return context.Background() } -func (m *mockQUICStream) CancelWrite(quic.StreamErrorCode) {} -func (m *mockQUICStream) CancelRead(quic.StreamErrorCode) {} diff --git a/go.mod b/go.mod index 9cf587a9a..f125791e2 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,7 @@ require ( github.com/prometheus/client_golang v1.22.0 github.com/prometheus/client_model v0.6.2 github.com/prometheus/common v0.65.0 - github.com/quic-go/quic-go v0.52.0 + github.com/quic-go/quic-go v0.53.0 go.etcd.io/etcd/api/v3 v3.6.1 go.etcd.io/etcd/client/v3 v3.6.1 go.uber.org/automaxprocs v1.6.0 @@ -106,7 +106,6 @@ require ( github.com/go-openapi/jsonpointer v0.21.0 // indirect github.com/go-openapi/jsonreference v0.20.2 // indirect github.com/go-openapi/swag v0.23.0 // indirect - github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/go-viper/mapstructure/v2 v2.3.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt/v4 v4.5.2 // indirect @@ -114,7 +113,6 @@ require ( github.com/google/gnostic-models v0.6.8 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/gofuzz v1.2.0 // indirect - github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect diff --git a/go.sum b/go.sum index ed8cc9638..51b8e22be 100644 --- a/go.sum +++ b/go.sum @@ -162,6 +162,7 @@ github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/go-viper/mapstructure/v2 v2.3.0 h1:27XbWsHIqhbdR5TIC911OfYvgSaW93HM+dX7970Q7jk= @@ -293,8 +294,8 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg= github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= -github.com/quic-go/quic-go v0.52.0 h1:/SlHrCRElyaU6MaEPKqKr9z83sBg2v4FLLvWM+Z47pA= -github.com/quic-go/quic-go v0.52.0/go.mod h1:MFlGGpcpJqRAfmYi6NC2cptDPSxRWTOGNuP4wqrWmzQ= +github.com/quic-go/quic-go v0.53.0 h1:QHX46sISpG2S03dPeZBgVIZp8dGagIaiu2FiVYvpCZI= +github.com/quic-go/quic-go v0.53.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY= github.com/richardartoul/molecule v1.0.1-0.20240531184615-7ca0df43c0b3 h1:4+LEVOB87y175cLJC/mbsgKmoDOjrBldtXvioEy96WY= github.com/richardartoul/molecule v1.0.1-0.20240531184615-7ca0df43c0b3/go.mod h1:vl5+MqJ1nBINuSsUI2mGgH79UweUT/B5Fy8857PqyyI= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= diff --git a/test/quic_test.go b/test/quic_test.go index 1bc06a24c..e8d673d74 100644 --- a/test/quic_test.go +++ b/test/quic_test.go @@ -153,7 +153,7 @@ func TestQUICStreamLimits(t *testing.T) { var mu sync.Mutex // Create a slice to store all the streams so we can keep them open - streams := make([]quic.Stream, 0, streamCount) + streams := make([]*quic.Stream, 0, streamCount) streamsMu := sync.Mutex{} // Attempt to open exactly the configured number of streams