Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support reading Vault's address from Agent's config file #6306

Merged
merged 8 commits into from
Feb 28, 2019
Merged
9 changes: 5 additions & 4 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ import (
"golang.org/x/time/rate"
)

const EnvVaultAgentAddress = "VAULT_AGENT_ADDR"
const EnvVaultAddress = "VAULT_ADDR"
const EnvVaultAgentAddr = "VAULT_AGENT_ADDR"
const EnvVaultCACert = "VAULT_CACERT"
const EnvVaultCAPath = "VAULT_CAPATH"
const EnvVaultClientCert = "VAULT_CLIENT_CERT"
const EnvVaultClientKey = "VAULT_CLIENT_KEY"
const EnvVaultClientTimeout = "VAULT_CLIENT_TIMEOUT"
const EnvVaultInsecure = "VAULT_SKIP_VERIFY"
const EnvVaultSkipVerify = "VAULT_SKIP_VERIFY"
const EnvVaultNamespace = "VAULT_NAMESPACE"
const EnvVaultTLSServerName = "VAULT_TLS_SERVER_NAME"
const EnvVaultWrapTTL = "VAULT_WRAP_TTL"
const EnvVaultMaxRetries = "VAULT_MAX_RETRIES"
Expand Down Expand Up @@ -243,7 +244,7 @@ func (c *Config) ReadEnvironment() error {
if v := os.Getenv(EnvVaultAddress); v != "" {
envAddress = v
}
if v := os.Getenv(EnvVaultAgentAddress); v != "" {
if v := os.Getenv(EnvVaultAgentAddr); v != "" {
envAgentAddress = v
}
if v := os.Getenv(EnvVaultMaxRetries); v != "" {
Expand Down Expand Up @@ -279,7 +280,7 @@ func (c *Config) ReadEnvironment() error {
}
envClientTimeout = clientTimeout
}
if v := os.Getenv(EnvVaultInsecure); v != "" {
if v := os.Getenv(EnvVaultSkipVerify); v != "" {
var err error
envInsecure, err = strconv.ParseBool(v)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,19 @@ func TestClientEnvSettings(t *testing.T) {
oldCAPath := os.Getenv(EnvVaultCAPath)
oldClientCert := os.Getenv(EnvVaultClientCert)
oldClientKey := os.Getenv(EnvVaultClientKey)
oldSkipVerify := os.Getenv(EnvVaultInsecure)
oldSkipVerify := os.Getenv(EnvVaultSkipVerify)
oldMaxRetries := os.Getenv(EnvVaultMaxRetries)
os.Setenv(EnvVaultCACert, cwd+"/test-fixtures/keys/cert.pem")
os.Setenv(EnvVaultCAPath, cwd+"/test-fixtures/keys")
os.Setenv(EnvVaultClientCert, cwd+"/test-fixtures/keys/cert.pem")
os.Setenv(EnvVaultClientKey, cwd+"/test-fixtures/keys/key.pem")
os.Setenv(EnvVaultInsecure, "true")
os.Setenv(EnvVaultSkipVerify, "true")
os.Setenv(EnvVaultMaxRetries, "5")
defer os.Setenv(EnvVaultCACert, oldCACert)
defer os.Setenv(EnvVaultCAPath, oldCAPath)
defer os.Setenv(EnvVaultClientCert, oldClientCert)
defer os.Setenv(EnvVaultClientKey, oldClientKey)
defer os.Setenv(EnvVaultInsecure, oldSkipVerify)
defer os.Setenv(EnvVaultSkipVerify, oldSkipVerify)
defer os.Setenv(EnvVaultMaxRetries, oldMaxRetries)

config := DefaultConfig()
Expand Down
79 changes: 79 additions & 0 deletions command/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package command

import (
"context"
"flag"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -206,6 +207,33 @@ func (c *AgentCommand) Run(args []string) int {
return 1
}

if config.Vault != nil {
c.setStringFlag(f, config.Vault.Address, &StringVar{
Name: flagNameAddress,
Target: &c.flagAddress,
Default: "https://127.0.0.1:8200",
EnvVar: api.EnvVaultAddress,
})
c.setStringFlag(f, config.Vault.CACert, &StringVar{
Name: flagNameCACert,
Target: &c.flagCACert,
Default: "",
EnvVar: api.EnvVaultCACert,
})
c.setStringFlag(f, config.Vault.CAPath, &StringVar{
Name: flagNameCAPath,
Target: &c.flagCAPath,
Default: "",
EnvVar: api.EnvVaultCAPath,
})
c.setBoolFlag(f, config.Vault.TLSSkipVerify, &BoolVar{
Name: flagNameTLSSkipVerify,
Target: &c.flagTLSSkipVerify,
Default: false,
EnvVar: api.EnvVaultSkipVerify,
})
}

infoKeys := make([]string, 0, 10)
info := make(map[string]string)
info["log level"] = c.flagLogLevel
Expand Down Expand Up @@ -235,6 +263,9 @@ func (c *AgentCommand) Run(args []string) int {
return 0
}

// Ignore any setting of agent's address. This client is used by the agent
// to reach out to Vault. This should never loop back to agent.
c.flagAgentAddress = ""
client, err := c.Client()
if err != nil {
c.UI.Error(fmt.Sprintf(
Expand Down Expand Up @@ -472,6 +503,54 @@ func (c *AgentCommand) Run(args []string) int {
return 0
}

func (c *AgentCommand) setStringFlag(f *FlagSets, configVal string, fVar *StringVar) {
var isFlagSet bool
f.Visit(func(f *flag.Flag) {
if f.Name == fVar.Name {
isFlagSet = true
}
})

flagEnvValue, flagEnvSet := os.LookupEnv(fVar.EnvVar)
switch {
case isFlagSet:
// Don't do anything as the flag is already set from the command line
case flagEnvSet:
// Use value from env var
*fVar.Target = flagEnvValue
case configVal != "":
// Use value from config
*fVar.Target = configVal
default:
// Use the default value
*fVar.Target = fVar.Default
}
}

func (c *AgentCommand) setBoolFlag(f *FlagSets, configVal bool, fVar *BoolVar) {
var isFlagSet bool
f.Visit(func(f *flag.Flag) {
if f.Name == fVar.Name {
isFlagSet = true
}
})

flagEnvValue, flagEnvSet := os.LookupEnv(fVar.EnvVar)
switch {
case isFlagSet:
// Don't do anything as the flag is already set from the command line
case flagEnvSet:
// Use value from env var
*fVar.Target = flagEnvValue != ""
case configVal == true:
// Use value from config
*fVar.Target = configVal
default:
// Use the default value
*fVar.Target = fVar.Default
}
}

// storePidFile is used to write out our PID to a file if necessary
func (c *AgentCommand) storePidFile(pidPath string) error {
// Quit fast if no pidfile
Expand Down
34 changes: 34 additions & 0 deletions command/agent/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ type Config struct {
ExitAfterAuth bool `hcl:"exit_after_auth"`
PidFile string `hcl:"pid_file"`
Cache *Cache `hcl:"cache"`
Vault *Vault `hcl:"vault"`
}

type Vault struct {
Address string `hcl:"address"`
CACert string `hcl:"ca_cert"`
CAPath string `hcl:"ca_path"`
TLSSkipVerify bool `hcl:"tls_skip_verify"`
}

type Cache struct {
Expand Down Expand Up @@ -107,9 +115,35 @@ func LoadConfig(path string, logger log.Logger) (*Config, error) {
return nil, errwrap.Wrapf("error parsing 'cache':{{err}}", err)
}

err = parseVault(&result, list)
if err != nil {
return nil, errwrap.Wrapf("error parsing 'vault':{{err}}", err)
}

return &result, nil
}

func parseVault(result *Config, list *ast.ObjectList) error {
name := "vault"

vaultList := list.Filter(name)
if len(vaultList.Items) > 1 {
return fmt.Errorf("one and only one %q block is required", name)
}

item := vaultList.Items[0]

var v Vault
err := hcl.DecodeObject(&v, item.Val)
if err != nil {
return err
}

result.Vault = &v

return nil
}

func parseCache(result *Config, list *ast.ObjectList) error {
name := "cache"

Expand Down
6 changes: 6 additions & 0 deletions command/agent/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ func TestLoadConfigFile_AgentCache(t *testing.T) {
},
},
},
Vault: &Vault{
Address: "http://127.0.0.1:1111",
CACert: "config_ca_cert",
CAPath: "config_ca_path",
TLSSkipVerify: true,
},
PidFile: "./pidfile",
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,10 @@ cache {
tls_cert_file = "/path/to/cacert.pem"
}
}

vault {
address = "http://127.0.0.1:1111"
ca_cert = "config_ca_cert"
ca_path = "config_ca_path"
tls_skip_verify = "true"
}
7 changes: 7 additions & 0 deletions command/agent/config/test-fixtures/config-cache.hcl
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,10 @@ cache {
tls_cert_file = "/path/to/cacert.pem"
}
}

vault {
address = "http://127.0.0.1:1111"
ca_cert = "config_ca_cert"
ca_path = "config_ca_path"
tls_skip_verify = "true"
}
6 changes: 3 additions & 3 deletions command/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,15 +188,15 @@ cache {
}
}()

originalVaultAgentAddress := os.Getenv(api.EnvVaultAgentAddress)
originalVaultAgentAddress := os.Getenv(api.EnvVaultAgentAddr)

// Create a client that talks to the agent
os.Setenv(api.EnvVaultAgentAddress, socketf)
os.Setenv(api.EnvVaultAgentAddr, socketf)
testClient, err := api.NewClient(api.DefaultConfig())
if err != nil {
t.Fatal(err)
}
os.Setenv(api.EnvVaultAgentAddress, originalVaultAgentAddress)
os.Setenv(api.EnvVaultAgentAddr, originalVaultAgentAddress)

// Start the agent
go cmd.Run([]string{"-config", conf})
Expand Down
28 changes: 14 additions & 14 deletions command/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
f := set.NewFlagSet("HTTP Options")

addrStringVar := &StringVar{
Name: "address",
Name: flagNameAddress,
Target: &c.flagAddress,
EnvVar: "VAULT_ADDR",
EnvVar: api.EnvVaultAddress,
Completion: complete.PredictAnything,
Usage: "Address of the Vault server.",
}
Expand All @@ -227,28 +227,28 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
agentAddrStringVar := &StringVar{
Name: "agent-address",
Target: &c.flagAgentAddress,
EnvVar: "VAULT_AGENT_ADDR",
EnvVar: api.EnvVaultAgentAddr,
Completion: complete.PredictAnything,
Usage: "Address of the Agent.",
}
f.StringVar(agentAddrStringVar)

f.StringVar(&StringVar{
Name: "ca-cert",
Name: flagNameCACert,
Target: &c.flagCACert,
Default: "",
EnvVar: "VAULT_CACERT",
EnvVar: api.EnvVaultCACert,
Completion: complete.PredictFiles("*"),
Usage: "Path on the local disk to a single PEM-encoded CA " +
"certificate to verify the Vault server's SSL certificate. This " +
"takes precedence over -ca-path.",
})

f.StringVar(&StringVar{
Name: "ca-path",
Name: flagNameCAPath,
Target: &c.flagCAPath,
Default: "",
EnvVar: "VAULT_CAPATH",
EnvVar: api.EnvVaultCAPath,
Completion: complete.PredictDirs("*"),
Usage: "Path on the local disk to a directory of PEM-encoded CA " +
"certificates to verify the Vault server's SSL certificate.",
Expand All @@ -258,7 +258,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Name: "client-cert",
Target: &c.flagClientCert,
Default: "",
EnvVar: "VAULT_CLIENT_CERT",
EnvVar: api.EnvVaultClientCert,
Completion: complete.PredictFiles("*"),
Usage: "Path on the local disk to a single PEM-encoded CA " +
"certificate to use for TLS authentication to the Vault server. If " +
Expand All @@ -269,7 +269,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Name: "client-key",
Target: &c.flagClientKey,
Default: "",
EnvVar: "VAULT_CLIENT_KEY",
EnvVar: api.EnvVaultClientKey,
Completion: complete.PredictFiles("*"),
Usage: "Path on the local disk to a single PEM-encoded private key " +
"matching the client certificate from -client-cert.",
Expand All @@ -279,7 +279,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Name: "namespace",
Target: &c.flagNamespace,
Default: notSetValue, // this can never be a real value
EnvVar: "VAULT_NAMESPACE",
EnvVar: api.EnvVaultNamespace,
Completion: complete.PredictAnything,
Usage: "The namespace to use for the command. Setting this is not " +
"necessary but allows using relative paths. -ns can be used as " +
Expand All @@ -299,17 +299,17 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Name: "tls-server-name",
Target: &c.flagTLSServerName,
Default: "",
EnvVar: "VAULT_TLS_SERVER_NAME",
EnvVar: api.EnvVaultTLSServerName,
Completion: complete.PredictAnything,
Usage: "Name to use as the SNI host when connecting to the Vault " +
"server via TLS.",
})

f.BoolVar(&BoolVar{
Name: "tls-skip-verify",
Name: flagNameTLSSkipVerify,
Target: &c.flagTLSSkipVerify,
Default: false,
EnvVar: "VAULT_SKIP_VERIFY",
EnvVar: api.EnvVaultSkipVerify,
Usage: "Disable verification of TLS certificates. Using this option " +
"is highly discouraged and decreases the security of data " +
"transmissions to and from the Vault server.",
Expand All @@ -327,7 +327,7 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Name: "wrap-ttl",
Target: &c.flagWrapTTL,
Default: 0,
EnvVar: "VAULT_WRAP_TTL",
EnvVar: api.EnvVaultWrapTTL,
Completion: complete.PredictAnything,
Usage: "Wraps the response in a cubbyhole token with the requested " +
"TTL. The response is available via the \"vault unwrap\" command. " +
Expand Down
2 changes: 1 addition & 1 deletion command/base_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ func (f *FlagSet) VarFlag(i *VarFlag) {
}

// Var is a lower-level API for adding something to the flags. It should be used
// wtih caution, since it bypasses all validation. Consider VarFlag instead.
// with caution, since it bypasses all validation. Consider VarFlag instead.
func (f *FlagSet) Var(value flag.Value, name, usage string) {
f.mainSet.Var(value, name, usage)
f.flagSet.Var(value, name, usage)
Expand Down
Loading