diff --git a/integration/cluster.go b/integration/cluster.go index fd7330a8533..7af9d77a1a3 100644 --- a/integration/cluster.go +++ b/integration/cluster.go @@ -76,6 +76,13 @@ var ( ClientCertAuth: true, } + testTLSInfoExpired = transport.TLSInfo{ + KeyFile: "./fixtures-expired/server-key.pem", + CertFile: "./fixtures-expired/server.pem", + TrustedCAFile: "./fixtures-expired/etcd-root-ca.pem", + ClientCertAuth: true, + } + plog = capnslog.NewPackageLogger("github.com/coreos/etcd", "integration") ) diff --git a/integration/util_test.go b/integration/util_test.go new file mode 100644 index 00000000000..18894198016 --- /dev/null +++ b/integration/util_test.go @@ -0,0 +1,62 @@ +// Copyright 2017 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "io" + "os" + "path/filepath" + + "github.com/coreos/etcd/pkg/transport" +) + +// copyTLSFiles clones certs files to dst directory. +func copyTLSFiles(ti transport.TLSInfo, dst string) (transport.TLSInfo, error) { + ci := transport.TLSInfo{ + KeyFile: filepath.Join(dst, "server-key.pem"), + CertFile: filepath.Join(dst, "server.pem"), + TrustedCAFile: filepath.Join(dst, "etcd-root-ca.pem"), + ClientCertAuth: ti.ClientCertAuth, + } + if err := copyFile(ti.KeyFile, ci.KeyFile); err != nil { + return transport.TLSInfo{}, err + } + if err := copyFile(ti.CertFile, ci.CertFile); err != nil { + return transport.TLSInfo{}, err + } + if err := copyFile(ti.TrustedCAFile, ci.TrustedCAFile); err != nil { + return transport.TLSInfo{}, err + } + return ci, nil +} + +func copyFile(src, dst string) error { + f, err := os.Open(src) + if err != nil { + return err + } + defer f.Close() + + w, err := os.Create(dst) + if err != nil { + return err + } + defer w.Close() + + if _, err = io.Copy(w, f); err != nil { + return err + } + return w.Sync() +} diff --git a/integration/v3_grpc_test.go b/integration/v3_grpc_test.go index 5113821def1..6ebc82d23c4 100644 --- a/integration/v3_grpc_test.go +++ b/integration/v3_grpc_test.go @@ -16,17 +16,21 @@ package integration import ( "bytes" + "crypto/tls" "fmt" + "io/ioutil" "math/rand" "os" "reflect" "testing" "time" + "github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/etcdserver/api/v3rpc" "github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes" pb "github.com/coreos/etcd/etcdserver/etcdserverpb" "github.com/coreos/etcd/pkg/testutil" + "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/metadata" @@ -1374,6 +1378,207 @@ func TestTLSGRPCAcceptSecureAll(t *testing.T) { } } +// TestTLSReloadAtomicReplace ensures server reloads expired/valid certs +// when all certs are atomically replaced by directory renaming. +// And expects server to reject client requests, and vice versa. +func TestTLSReloadAtomicReplace(t *testing.T) { + defer testutil.AfterTest(t) + + // clone valid,expired certs to separate directories for atomic renaming + vDir, err := ioutil.TempDir(os.TempDir(), "fixtures-valid") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(vDir) + ts, err := copyTLSFiles(testTLSInfo, vDir) + if err != nil { + t.Fatal(err) + } + eDir, err := ioutil.TempDir(os.TempDir(), "fixtures-expired") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(eDir) + if _, err = copyTLSFiles(testTLSInfoExpired, eDir); err != nil { + t.Fatal(err) + } + + tDir, err := ioutil.TempDir(os.TempDir(), "fixtures") + if err != nil { + t.Fatal(err) + } + os.RemoveAll(tDir) + defer os.RemoveAll(tDir) + + // start with valid certs + clus := NewClusterV3(t, &ClusterConfig{Size: 1, PeerTLS: &ts, ClientTLS: &ts}) + defer clus.Terminate(t) + + // concurrent client dialing while certs transition from valid to expired + errc := make(chan error, 1) + go func() { + for { + cc, err := ts.ClientConfig() + if err != nil { + if os.IsNotExist(err) { + // from concurrent renaming + continue + } + t.Fatal(err) + } + cli, cerr := clientv3.New(clientv3.Config{ + Endpoints: []string{clus.Members[0].GRPCAddr()}, + DialTimeout: time.Second, + TLS: cc, + }) + if cerr != nil { + errc <- cerr + return + } + cli.Close() + } + }() + + // replace certs directory with expired ones + if err = os.Rename(vDir, tDir); err != nil { + t.Fatal(err) + } + if err = os.Rename(eDir, vDir); err != nil { + t.Fatal(err) + } + + // after rename, + // 'vDir' contains expired certs + // 'tDir' contains valid certs + // 'eDir' does not exist + + select { + case err = <-errc: + if err != grpc.ErrClientConnTimeout { + t.Fatalf("expected %v, got %v", grpc.ErrClientConnTimeout, err) + } + case <-time.After(5 * time.Second): + t.Fatal("failed to receive dial timeout error") + } + + // now, replace expired certs back with valid ones + if err = os.Rename(tDir, eDir); err != nil { + t.Fatal(err) + } + if err = os.Rename(vDir, tDir); err != nil { + t.Fatal(err) + } + if err = os.Rename(eDir, vDir); err != nil { + t.Fatal(err) + } + + // new incoming client request should trigger + // listener to reload valid certs + var tls *tls.Config + tls, err = ts.ClientConfig() + if err != nil { + t.Fatal(err) + } + var cl *clientv3.Client + cl, err = clientv3.New(clientv3.Config{ + Endpoints: []string{clus.Members[0].GRPCAddr()}, + DialTimeout: time.Second, + TLS: tls, + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + cl.Close() +} + +// TestTLSReloadCopy ensures server reloads expired/valid certs +// when new certs are copied over, one by one. And expects server +// to reject client requests, and vice versa. +func TestTLSReloadCopy(t *testing.T) { + defer testutil.AfterTest(t) + + // clone certs directory, free to overwrite + cDir, err := ioutil.TempDir(os.TempDir(), "fixtures-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(cDir) + ts, err := copyTLSFiles(testTLSInfo, cDir) + if err != nil { + t.Fatal(err) + } + + // start with valid certs + clus := NewClusterV3(t, &ClusterConfig{Size: 1, PeerTLS: &ts, ClientTLS: &ts}) + defer clus.Terminate(t) + + // concurrent client dialing while certs transition from valid to expired + errc := make(chan error, 1) + go func() { + for { + cc, err := ts.ClientConfig() + if err != nil { + // from concurrent certs overwriting + switch err.Error() { + case "tls: private key does not match public key": + fallthrough + case "tls: failed to find any PEM data in key input": + continue + } + t.Fatal(err) + } + cli, cerr := clientv3.New(clientv3.Config{ + Endpoints: []string{clus.Members[0].GRPCAddr()}, + DialTimeout: time.Second, + TLS: cc, + }) + if cerr != nil { + errc <- cerr + return + } + cli.Close() + } + }() + + // overwrite valid certs with expired ones + // (e.g. simulate cert expiration in practice) + if _, err = copyTLSFiles(testTLSInfoExpired, cDir); err != nil { + t.Fatal(err) + } + + select { + case gerr := <-errc: + if gerr != grpc.ErrClientConnTimeout { + t.Fatalf("expected %v, got %v", grpc.ErrClientConnTimeout, gerr) + } + case <-time.After(5 * time.Second): + t.Fatal("failed to receive dial timeout error") + } + + // now, replace expired certs back with valid ones + if _, err = copyTLSFiles(testTLSInfo, cDir); err != nil { + t.Fatal(err) + } + + // new incoming client request should trigger + // listener to reload valid certs + var tls *tls.Config + tls, err = ts.ClientConfig() + if err != nil { + t.Fatal(err) + } + var cl *clientv3.Client + cl, err = clientv3.New(clientv3.Config{ + Endpoints: []string{clus.Members[0].GRPCAddr()}, + DialTimeout: time.Second, + TLS: tls, + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + cl.Close() +} + func TestGRPCRequireLeader(t *testing.T) { defer testutil.AfterTest(t)