Skip to content

Commit

Permalink
Add -dev-tls-san flag (#22657)
Browse files Browse the repository at this point in the history
* Add -dev-tls-san flag

This is helpful when wanting to set up a dev server with TLS in Kubernetes
and any other situations where the dev server may not be the same machine
as the Vault client (e.g. in combination with some /etc/hosts entries)

* Automatically add (best-effort only) -dev-listen-address host to extraSANs
  • Loading branch information
tomhjp committed Aug 31, 2023
1 parent 8da06f9 commit 8764921
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 13 deletions.
3 changes: 3 additions & 0 deletions changelog/22657.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
command/server: add `-dev-tls-san` flag to configure subject alternative names for the certificate generated when using `-dev-tls`.
```
25 changes: 24 additions & 1 deletion command/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ type ServerCommand struct {
flagDev bool
flagDevTLS bool
flagDevTLSCertDir string
flagDevTLSSANs []string
flagDevRootTokenID string
flagDevListenAddr string
flagDevNoStoreToken bool
Expand Down Expand Up @@ -256,6 +257,18 @@ func (c *ServerCommand) Flags() *FlagSets {
"specified. If left unset, files are generated in a temporary directory.",
})

f.StringSliceVar(&StringSliceVar{
Name: "dev-tls-san",
Target: &c.flagDevTLSSANs,
Default: nil,
Usage: "Additional Subject Alternative Name (as a DNS name or IP address) " +
"to generate the certificate with if `-dev-tls` is specified. The " +
"certificate will always use localhost, localhost4, localhost6, " +
"localhost.localdomain, and the host name as alternate DNS names, " +
"and 127.0.0.1 as an alternate IP address. This flag can be specified " +
"multiple times to specify multiple SANs.",
})

f.StringVar(&StringVar{
Name: "dev-root-token-id",
Target: &c.flagDevRootTokenID,
Expand Down Expand Up @@ -977,7 +990,17 @@ func configureDevTLS(c *ServerCommand) (func(), *server.Config, string, error) {
return nil, nil, certDir, err
}
}
config, err = server.DevTLSConfig(devStorageType, certDir)
extraSANs := c.flagDevTLSSANs
host, _, err := net.SplitHostPort(c.flagDevListenAddr)
if err == nil {
// 127.0.0.1 is the default, and already included in the SANs.
// Empty host means listen on all interfaces, but users should use the
// -dev-tls-san flag to get the right SANs in that case.
if host != "" && host != "127.0.0.1" {
extraSANs = append(extraSANs, host)
}
}
config, err = server.DevTLSConfig(devStorageType, certDir, extraSANs)

f = func() {
if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCAFilename)); err != nil {
Expand Down
4 changes: 2 additions & 2 deletions command/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,13 @@ ui = true
}

// DevTLSConfig is a Config that is used for dev tls mode of Vault.
func DevTLSConfig(storageType, certDir string) (*Config, error) {
func DevTLSConfig(storageType, certDir string, extraSANs []string) (*Config, error) {
ca, err := GenerateCA()
if err != nil {
return nil, err
}

cert, key, err := GenerateCert(ca.Template, ca.Signer)
cert, key, err := generateCert(ca.Template, ca.Signer, extraSANs)
if err != nil {
return nil, err
}
Expand Down
11 changes: 9 additions & 2 deletions command/server/tls_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ type CaCert struct {
Signer crypto.Signer
}

// GenerateCert creates a new leaf cert from provided CA template and signer
func GenerateCert(caCertTemplate *x509.Certificate, caSigner crypto.Signer) (string, string, error) {
// generateCert creates a new leaf cert from provided CA template and signer
func generateCert(caCertTemplate *x509.Certificate, caSigner crypto.Signer, extraSANs []string) (string, string, error) {
// Create the private key
signer, keyPEM, err := privateKey()
if err != nil {
Expand Down Expand Up @@ -80,6 +80,13 @@ func GenerateCert(caCertTemplate *x509.Certificate, caSigner crypto.Signer) (str
if !foundHostname {
template.DNSNames = append(template.DNSNames, hostname)
}
for _, san := range extraSANs {
if ip := net.ParseIP(san); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, san)
}
}

bs, err := x509.CreateCertificate(
rand.Reader, &template, caCertTemplate, signer.Public(), caSigner)
Expand Down
80 changes: 80 additions & 0 deletions command/server/tls_util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package server

import (
"crypto/x509"
"encoding/pem"
"testing"

"github.com/hashicorp/go-secure-stdlib/strutil"
)

// TestGenerateCertExtraSans ensures the implementation backing the flag
// -dev-tls-san populates alternate DNS and IP address names in the generated
// certificate as expected.
func TestGenerateCertExtraSans(t *testing.T) {
ca, err := GenerateCA()
if err != nil {
t.Fatal(err)
}

for name, tc := range map[string]struct {
extraSans []string
expectedDNSNames []string
expectedIPAddresses []string
}{
"empty": {},
"DNS names": {
extraSans: []string{"foo", "foo.bar"},
expectedDNSNames: []string{"foo", "foo.bar"},
},
"IP addresses": {
extraSans: []string{"0.0.0.0", "::1"},
expectedIPAddresses: []string{"0.0.0.0", "::1"},
},
"mixed": {
extraSans: []string{"bar", "0.0.0.0", "::1"},
expectedDNSNames: []string{"bar"},
expectedIPAddresses: []string{"0.0.0.0", "::1"},
},
} {
t.Run(name, func(t *testing.T) {
certStr, _, err := generateCert(ca.Template, ca.Signer, tc.extraSans)
if err != nil {
t.Fatal(err)
}

block, _ := pem.Decode([]byte(certStr))
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
t.Fatal(err)
}

expectedDNSNamesLen := len(tc.expectedDNSNames) + 5
if len(cert.DNSNames) != expectedDNSNamesLen {
t.Errorf("Wrong number of DNS names, expected %d but got %v", expectedDNSNamesLen, cert.DNSNames)
}
expectedIPAddrLen := len(tc.expectedIPAddresses) + 1
if len(cert.IPAddresses) != expectedIPAddrLen {
t.Errorf("Wrong number of IP addresses, expected %d but got %v", expectedIPAddrLen, cert.IPAddresses)
}

for _, expected := range tc.expectedDNSNames {
if !strutil.StrListContains(cert.DNSNames, expected) {
t.Errorf("Missing DNS name %s", expected)
}
}
for _, expected := range tc.expectedIPAddresses {
var found bool
for _, ip := range cert.IPAddresses {
if ip.String() == expected {
found = true
break
}
}
if !found {
t.Errorf("Missing IP address %s", expected)
}
}
})
}
}
23 changes: 15 additions & 8 deletions command/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/hashicorp/vault/sdk/physical"
physInmem "github.com/hashicorp/vault/sdk/physical/inmem"
"github.com/mitchellh/cli"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -383,13 +384,19 @@ func TestConfigureDevTLS(t *testing.T) {
fun()
}

require.Equal(t, testcase.DeferFuncNotNil, (fun != nil), "test description %s", testcase.TestDescription)
require.Equal(t, testcase.ConfigNotNil, cfg != nil, "test description %s", testcase.TestDescription)
if testcase.ConfigNotNil {
require.True(t, len(cfg.Listeners) > 0, "test description %s", testcase.TestDescription)
require.Equal(t, testcase.TLSDisable, cfg.Listeners[0].TLSDisable, "test description %s", testcase.TestDescription)
}
require.Equal(t, testcase.CertPathEmpty, len(certPath) == 0, "test description %s", testcase.TestDescription)
require.Equal(t, testcase.ErrNotNil, (err != nil), "test description %s", testcase.TestDescription)
t.Run(testcase.TestDescription, func(t *testing.T) {
assert.Equal(t, testcase.DeferFuncNotNil, (fun != nil))
assert.Equal(t, testcase.ConfigNotNil, cfg != nil)
if testcase.ConfigNotNil && cfg != nil {
assert.True(t, len(cfg.Listeners) > 0)
assert.Equal(t, testcase.TLSDisable, cfg.Listeners[0].TLSDisable)
}
assert.Equal(t, testcase.CertPathEmpty, len(certPath) == 0)
if testcase.ErrNotNil {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

0 comments on commit 8764921

Please sign in to comment.