diff --git a/ipn/ipnlocal/drive.go b/ipn/ipnlocal/drive.go index 7d6dc2427..456cd4544 100644 --- a/ipn/ipnlocal/drive.go +++ b/ipn/ipnlocal/drive.go @@ -433,7 +433,7 @@ func (rbw *responseBodyWrapper) Close() error { // b.Dialer().PeerAPITransport() with metrics tracking. type driveTransport struct { b *LocalBackend - tr *http.Transport + tr http.RoundTripper } func (b *LocalBackend) newDriveTransport() *driveTransport { @@ -443,7 +443,7 @@ func (b *LocalBackend) newDriveTransport() *driveTransport { } } -func (dt *driveTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { +func (dt *driveTransport) RoundTrip(req *http.Request) (*http.Response, error) { // Some WebDAV clients include origin and refer headers, which peerapi does // not like. Remove them. req.Header.Del("origin") @@ -455,42 +455,45 @@ func (dt *driveTransport) RoundTrip(req *http.Request) (resp *http.Response, err req.Body = bw } - defer func() { - contentType := "unknown" - if ct := req.Header.Get("Content-Type"); ct != "" { - contentType = ct - } + resp, err := dt.tr.RoundTrip(req) + if err != nil { + return nil, err + } - dt.b.mu.Lock() - selfNodeKey := dt.b.currentNode().Self().Key().ShortString() - dt.b.mu.Unlock() - n, _, ok := dt.b.WhoIs("tcp", netip.MustParseAddrPort(req.URL.Host)) - shareNodeKey := "unknown" - if ok { - shareNodeKey = string(n.Key().ShortString()) - } + contentType := "unknown" + if ct := req.Header.Get("Content-Type"); ct != "" { + contentType = ct + } - rbw := responseBodyWrapper{ - log: dt.b.logf, - logVerbose: req.Method != httpm.GET && req.Method != httpm.PUT, // other requests like PROPFIND are quite chatty, so we log those at verbose level - method: req.Method, - bytesTx: int64(bw.bytesRead), - selfNodeKey: selfNodeKey, - shareNodeKey: shareNodeKey, - contentType: contentType, - contentLength: resp.ContentLength, - fileExtension: parseDriveFileExtensionForLog(req.URL.Path), - statusCode: resp.StatusCode, - ReadCloser: resp.Body, - } + dt.b.mu.Lock() + selfNodeKey := dt.b.currentNode().Self().Key().ShortString() + dt.b.mu.Unlock() + n, _, ok := dt.b.WhoIs("tcp", netip.MustParseAddrPort(req.URL.Host)) + shareNodeKey := "unknown" + if ok { + shareNodeKey = string(n.Key().ShortString()) + } - if resp.StatusCode >= 400 { - // in case of error response, just log immediately - rbw.logAccess("") - } else { - resp.Body = &rbw - } - }() + rbw := responseBodyWrapper{ + log: dt.b.logf, + logVerbose: req.Method != httpm.GET && req.Method != httpm.PUT, // other requests like PROPFIND are quite chatty, so we log those at verbose level + method: req.Method, + bytesTx: int64(bw.bytesRead), + selfNodeKey: selfNodeKey, + shareNodeKey: shareNodeKey, + contentType: contentType, + contentLength: resp.ContentLength, + fileExtension: parseDriveFileExtensionForLog(req.URL.Path), + statusCode: resp.StatusCode, + ReadCloser: resp.Body, + } - return dt.tr.RoundTrip(req) + if resp.StatusCode >= 400 { + // in case of error response, just log immediately + rbw.logAccess("") + } else { + resp.Body = &rbw + } + + return resp, nil } diff --git a/ipn/ipnlocal/drive_test.go b/ipn/ipnlocal/drive_test.go new file mode 100644 index 000000000..323c38214 --- /dev/null +++ b/ipn/ipnlocal/drive_test.go @@ -0,0 +1,50 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_drive + +package ipnlocal + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" +) + +// TestDriveTransportRoundTrip_NetworkError tests that driveTransport.RoundTrip +// doesn't panic when the underlying transport returns a nil response with an +// error. +// +// See: https://github.com/tailscale/tailscale/issues/17306 +func TestDriveTransportRoundTrip_NetworkError(t *testing.T) { + b := newTestLocalBackend(t) + + testErr := errors.New("network connection failed") + mockTransport := &mockRoundTripper{ + err: testErr, + } + dt := &driveTransport{ + b: b, + tr: mockTransport, + } + + req := httptest.NewRequest("GET", "http://100.64.0.1:1234/some/path", nil) + resp, err := dt.RoundTrip(req) + if err == nil { + t.Fatal("got nil error, expected non-nil") + } else if !errors.Is(err, testErr) { + t.Errorf("got error %v, expected %v", err, testErr) + } + if resp != nil { + t.Errorf("wanted nil response, got %v", resp) + } +} + +type mockRoundTripper struct { + err error +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return nil, m.err +}