Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add exit-after-auth functionality to agent #5013

Merged
merged 1 commit into from
Jul 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions command/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ func (c *AgentCommand) Run(args []string) int {
return 1
}

c.logger = logging.NewVaultLoggerWithWriter(c.logWriter, level)
if c.logger == nil {
c.logger = logging.NewVaultLoggerWithWriter(c.logWriter, level)
}

// Validation
if len(c.flagConfigs) != 1 {
Expand Down Expand Up @@ -313,8 +315,9 @@ func (c *AgentCommand) Run(args []string) int {
}

ss := sink.NewSinkServer(&sink.SinkServerConfig{
Logger: c.logger.Named("sink.server"),
Client: client,
Logger: c.logger.Named("sink.server"),
Client: client,
ExitAfterAuth: config.ExitAfterAuth,
})

ah := auth.NewAuthHandler(&auth.AuthHandlerConfig{
Expand Down Expand Up @@ -342,6 +345,9 @@ func (c *AgentCommand) Run(args []string) int {
}()

select {
case <-ss.DoneCh:
// This will happen if we exit-on-auth
c.logger.Info("sinks finished, exiting")
case <-c.ShutdownCh:
c.UI.Output("==> Vault agent shutdown triggered")
cancelFunc()
Expand Down
5 changes: 3 additions & 2 deletions command/agent/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ import (

// Config is the configuration for the vault server.
type Config struct {
AutoAuth *AutoAuth `hcl:"auto_auth"`
PidFile string `hcl:"pid_file"`
AutoAuth *AutoAuth `hcl:"auto_auth"`
ExitAfterAuth bool `hcl:"exit_after_auth"`
PidFile string `hcl:"pid_file"`
}

type AutoAuth struct {
Expand Down
64 changes: 3 additions & 61 deletions command/agent/jwt_end_to_end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ package agent

import (
"context"
"crypto/ecdsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"io/ioutil"
"os"
"testing"
Expand All @@ -24,50 +21,8 @@ import (
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
jose "gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
)

func getTestJWT(t *testing.T) (string, *ecdsa.PrivateKey) {
t.Helper()
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Issuer: "https://team-vault.auth0.com/",
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Audience: jwt.Audience{"https://vault.plugin.auth.jwt.test"},
}

privateCl := struct {
User string `json:"https://vault/user"`
Groups []string `json:"https://vault/groups"`
}{
"jeff",
[]string{"foo", "bar"},
}

var key *ecdsa.PrivateKey
block, _ := pem.Decode([]byte(ecdsaPrivKey))
if block != nil {
var err error
key, err = x509.ParseECPrivateKey(block.Bytes)
if err != nil {
t.Fatal(err)
}
}

sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, (&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
t.Fatal(err)
}

raw, err := jwt.Signed(sig).Claims(cl).Claims(privateCl).CompactSerialize()
if err != nil {
t.Fatal(err)
}

return raw, key
}

func TestJWTEndToEnd(t *testing.T) {
testJWTEndToEnd(t, false)
testJWTEndToEnd(t, true)
Expand Down Expand Up @@ -100,7 +55,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {

_, err = client.Logical().Write("auth/jwt/config", map[string]interface{}{
"bound_issuer": "https://team-vault.auth0.com/",
"jwt_validation_pubkeys": ecdsaPubKey,
"jwt_validation_pubkeys": TestECDSAPubKey,
})
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -248,7 +203,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
}

// Get a token
jwtToken, _ := getTestJWT(t)
jwtToken, _ := GetTestJWT(t)
if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil {
t.Fatal(err)
} else {
Expand Down Expand Up @@ -355,7 +310,7 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {

// Get another token to test the backend pushing the need to authenticate
// to the handler
jwtToken, _ = getTestJWT(t)
jwtToken, _ = GetTestJWT(t)
if err := ioutil.WriteFile(in, []byte(jwtToken), 0600); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -394,16 +349,3 @@ func testJWTEndToEnd(t *testing.T, ahWrapping bool) {
}
}
}

const (
ecdsaPrivKey string = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIKfldwWLPYsHjRL9EVTsjSbzTtcGRu6icohNfIqcb6A+oAoGCCqGSM49
AwEHoUQDQgAE4+SFvPwOy0miy/FiTT05HnwjpEbSq+7+1q9BFxAkzjgKnlkXk5qx
hzXQvRmS4w9ZsskoTZtuUI+XX7conJhzCQ==
-----END EC PRIVATE KEY-----`

ecdsaPubKey string = `-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE4+SFvPwOy0miy/FiTT05HnwjpEbS
q+7+1q9BFxAkzjgKnlkXk5qxhzXQvRmS4w9ZsskoTZtuUI+XX7conJhzCQ==
-----END PUBLIC KEY-----`
)
36 changes: 25 additions & 11 deletions command/agent/sink/sink.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io/ioutil"
"math/rand"
"os"
"sync/atomic"
"time"

"github.com/hashicorp/errwrap"
Expand Down Expand Up @@ -34,25 +35,30 @@ type SinkConfig struct {
}

type SinkServerConfig struct {
Logger hclog.Logger
Client *api.Client
Context context.Context
Logger hclog.Logger
Client *api.Client
Context context.Context
ExitAfterAuth bool
}

// SinkServer is responsible for pushing tokens to sinks
type SinkServer struct {
DoneCh chan struct{}
logger hclog.Logger
client *api.Client
random *rand.Rand
DoneCh chan struct{}
logger hclog.Logger
client *api.Client
random *rand.Rand
exitAfterAuth bool
remaining *int32
}

func NewSinkServer(conf *SinkServerConfig) *SinkServer {
ss := &SinkServer{
DoneCh: make(chan struct{}),
logger: conf.Logger,
client: conf.Client,
random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
DoneCh: make(chan struct{}),
logger: conf.Logger,
client: conf.Client,
random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
exitAfterAuth: conf.ExitAfterAuth,
remaining: new(int32),
}

return ss
Expand Down Expand Up @@ -86,6 +92,7 @@ func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*Si
for {
select {
case <-sinkCh:
atomic.AddInt32(ss.remaining, -1)
default:
break drainLoop
}
Expand Down Expand Up @@ -116,11 +123,13 @@ func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*Si
return currSink.WriteToken(currToken)
}
}
atomic.AddInt32(ss.remaining, 1)
sinkCh <- sinkFunc(s, token)
}
}

case sinkFunc := <-sinkCh:
atomic.AddInt32(ss.remaining, -1)
select {
case <-ctx.Done():
return
Expand All @@ -134,8 +143,13 @@ func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*Si
case <-ctx.Done():
return
case <-time.After(backoff):
atomic.AddInt32(ss.remaining, 1)
sinkCh <- sinkFunc
}
} else {
if atomic.LoadInt32(ss.remaining) == 0 && ss.exitAfterAuth {
return
}
}
}
}
Expand Down
65 changes: 65 additions & 0 deletions command/agent/testing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package agent

import (
"crypto/ecdsa"
"crypto/x509"
"encoding/pem"
"testing"
"time"

jose "gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
)

func GetTestJWT(t *testing.T) (string, *ecdsa.PrivateKey) {
t.Helper()
cl := jwt.Claims{
Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",
Issuer: "https://team-vault.auth0.com/",
NotBefore: jwt.NewNumericDate(time.Now().Add(-5 * time.Second)),
Audience: jwt.Audience{"https://vault.plugin.auth.jwt.test"},
}

privateCl := struct {
User string `json:"https://vault/user"`
Groups []string `json:"https://vault/groups"`
}{
"jeff",
[]string{"foo", "bar"},
}

var key *ecdsa.PrivateKey
block, _ := pem.Decode([]byte(TestECDSAPrivKey))
if block != nil {
var err error
key, err = x509.ParseECPrivateKey(block.Bytes)
if err != nil {
t.Fatal(err)
}
}

sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, (&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
t.Fatal(err)
}

raw, err := jwt.Signed(sig).Claims(cl).Claims(privateCl).CompactSerialize()
if err != nil {
t.Fatal(err)
}

return raw, key
}

const (
TestECDSAPrivKey string = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIKfldwWLPYsHjRL9EVTsjSbzTtcGRu6icohNfIqcb6A+oAoGCCqGSM49
AwEHoUQDQgAE4+SFvPwOy0miy/FiTT05HnwjpEbSq+7+1q9BFxAkzjgKnlkXk5qx
hzXQvRmS4w9ZsskoTZtuUI+XX7conJhzCQ==
-----END EC PRIVATE KEY-----`

TestECDSAPubKey string = `-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE4+SFvPwOy0miy/FiTT05HnwjpEbS
q+7+1q9BFxAkzjgKnlkXk5qxhzXQvRmS4w9ZsskoTZtuUI+XX7conJhzCQ==
-----END PUBLIC KEY-----`
)
Loading