diff --git a/command/server/listener.go b/command/server/listener.go index 28b063a0ef..f723bedd0c 100644 --- a/command/server/listener.go +++ b/command/server/listener.go @@ -1,6 +1,7 @@ package server import ( + "crypto/tls" "fmt" "net" ) @@ -23,3 +24,32 @@ func NewListener(t string, config map[string]string) (net.Listener, error) { return f(config) } + +func listenerWrapTLS( + ln net.Listener, config map[string]string) (net.Listener, error) { + if v, ok := config["tls_disable"]; ok && v != "" { + return ln, nil + } + + certFile, ok := config["tls_cert_file"] + if !ok { + return nil, fmt.Errorf("'tls_cert_file' must be set") + } + + keyFile, ok := config["tls_key_file"] + if !ok { + return nil, fmt.Errorf("'tls_key_file' must be set") + } + + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, fmt.Errorf("error loading TLS cert: %s", err) + } + + tlsConf := &tls.Config{} + tlsConf.Certificates = []tls.Certificate{cert} + tlsConf.NextProtos = []string{"http/1.1"} + + ln = tls.NewListener(ln, tlsConf) + return ln, nil +} diff --git a/command/server/listener_tcp.go b/command/server/listener_tcp.go index cff4d4c30a..9f0325064f 100644 --- a/command/server/listener_tcp.go +++ b/command/server/listener_tcp.go @@ -16,5 +16,5 @@ func tcpListenerFactory(config map[string]string) (net.Listener, error) { return nil, err } - return ln, nil + return listenerWrapTLS(ln, config) } diff --git a/command/server/listener_tcp_test.go b/command/server/listener_tcp_test.go index 61309a74a1..54457b66ff 100644 --- a/command/server/listener_tcp_test.go +++ b/command/server/listener_tcp_test.go @@ -7,7 +7,27 @@ import ( func TestTCPListener(t *testing.T) { ln, err := tcpListenerFactory(map[string]string{ - "address": "127.0.0.1:0", + "address": "127.0.0.1:0", + "tls_disable": "1", + }) + if err != nil { + t.Fatalf("err: %s", err) + } + + connFn := func(lnReal net.Listener) (net.Conn, error) { + return net.Dial("tcp", ln.Addr().String()) + } + + testListenerImpl(t, ln, connFn) +} + +func TestTCPListener_tls(t *testing.T) { + // TODO + t.Skip() + + ln, err := tcpListenerFactory(map[string]string{ + "address": "127.0.0.1:0", + "tls_disable": "1", }) if err != nil { t.Fatalf("err: %s", err)