Skip to content

Commit

Permalink
Unify token updating code (#3442)
Browse files Browse the repository at this point in the history
* bump fly-go

* unified helper for doing housekeeping on tokens

* replace disparate token updating code with unified config.MonitorTokens

* RPC for sending tokens to agent

bug fixes (thanks Tim)
  • Loading branch information
btoews committed Apr 15, 2024
1 parent 201c782 commit 22fb63e
Show file tree
Hide file tree
Showing 13 changed files with 741 additions and 472 deletions.
48 changes: 44 additions & 4 deletions agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/superfly/flyctl/agent/internal/proto"
"github.com/superfly/flyctl/gql"
"github.com/superfly/flyctl/internal/buildinfo"
"github.com/superfly/flyctl/internal/config"
"github.com/superfly/flyctl/internal/flag"
"github.com/superfly/flyctl/internal/logger"
"github.com/superfly/flyctl/internal/sentry"
Expand Down Expand Up @@ -127,14 +128,53 @@ const (
)

type Client struct {
network string
address string
dialer net.Dialer
network string
address string
dialer net.Dialer
agentRefusedTokens bool
}

var errDone = errors.New("done")

func (c *Client) do(parent context.Context, fn func(net.Conn) error) (err error) {
func (c *Client) do(ctx context.Context, fn func(net.Conn) error) (err error) {
if c.agentRefusedTokens {
return c.doNoTokens(ctx, fn)
}

toks := config.Tokens(ctx)
if toks.Empty() {
return c.doNoTokens(ctx, fn)
}

var tokArgs []string
if file := toks.FromFile(); file != "" {
tokArgs = append(tokArgs, "cfg", file)
} else {
tokArgs = append(tokArgs, "str", toks.All())
}

return c.doNoTokens(ctx, func(conn net.Conn) error {
if err := proto.Write(conn, "set-token", tokArgs...); err != nil {
return err
}

data, err := proto.Read(conn)

switch {
case err == nil && string(data) == "ok":
return fn(conn)
case err != nil:
return err
case isError(data):
c.agentRefusedTokens = true
return c.do(ctx, fn)
default:
return err
}
})
}

func (c *Client) doNoTokens(parent context.Context, fn func(net.Conn) error) (err error) {
var conn net.Conn
if conn, err = c.dialContext(parent); err != nil {
return err
Expand Down
132 changes: 43 additions & 89 deletions agent/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/superfly/flyctl/agent"
"github.com/superfly/flyctl/internal/config"
"github.com/superfly/flyctl/internal/env"
"github.com/superfly/flyctl/internal/flyutil"
"github.com/superfly/flyctl/internal/sentry"
"github.com/superfly/flyctl/internal/wireguard"
"github.com/superfly/flyctl/wg"
Expand All @@ -27,7 +28,6 @@ import (
type Options struct {
Socket string
Logger *log.Logger
Client *fly.Client
Background bool
ConfigFile string
ConfigWebsockets bool
Expand All @@ -51,11 +51,19 @@ func Run(ctx context.Context, opt Options) (err error) {
return
}

toks := config.Tokens(ctx)

monitorCtx, cancelMonitor := context.WithCancel(ctx)
config.MonitorTokens(monitorCtx, toks, nil)

err = (&server{
Options: opt,
listener: l,
currentChange: latestChangeAt,
tunnels: make(map[tunnelKey]*wg.Tunnel),
Options: opt,
listener: l,
runCtx: ctx,
currentChange: latestChangeAt,
tunnels: make(map[tunnelKey]*wg.Tunnel),
tokens: toks,
cancelTokenMonitoring: cancelMonitor,
}).serve(ctx, l)

return
Expand Down Expand Up @@ -112,9 +120,12 @@ type server struct {

listener net.Listener

mu sync.Mutex
currentChange time.Time
tunnels map[tunnelKey]*wg.Tunnel
runCtx context.Context
mu sync.Mutex
currentChange time.Time
tunnels map[tunnelKey]*wg.Tunnel
tokens *tokens.Tokens
cancelTokenMonitoring func()
}

type terminateError struct{ error }
Expand Down Expand Up @@ -142,20 +153,6 @@ func (s *server) serve(parent context.Context, l net.Listener) (err error) {
return nil
})

if toks := config.Tokens(ctx); len(toks.MacaroonTokens) != 0 {
eg.Go(func() error {
if f := toks.FromConfigFile; f == "" {
s.print("monitoring for token expiration")
s.updateMacaroonsInMemory(ctx)
} else {
s.print("monitoring for token changes and expiration")
s.updateMacaroonsInFile(ctx, f)
}

return nil
})
}

eg.Go(func() (err error) {
s.printf("OK %d", os.Getpid())
defer s.print("QUIT")
Expand Down Expand Up @@ -227,7 +224,7 @@ func (s *server) checkForConfigChange() (err error) {
return
}

func (s *server) buildTunnel(ctx context.Context, org *fly.Organization, recycle bool, network string) (tunnel *wg.Tunnel, err error) {
func (s *server) buildTunnel(ctx context.Context, org *fly.Organization, recycle bool, network string, client *fly.Client) (tunnel *wg.Tunnel, err error) {
s.mu.Lock()
defer s.mu.Unlock()

Expand All @@ -240,7 +237,7 @@ func (s *server) buildTunnel(ctx context.Context, org *fly.Organization, recycle
}

var state *wg.WireGuardState
if state, err = wireguard.StateForOrg(ctx, s.Client, org, os.Getenv("FLY_AGENT_WG_REGION"), "", recycle, network); err != nil {
if state, err = wireguard.StateForOrg(ctx, client, org, os.Getenv("FLY_AGENT_WG_REGION"), "", recycle, network); err != nil {
return
}

Expand Down Expand Up @@ -381,7 +378,7 @@ func (s *server) clean(ctx context.Context) {
break
}

if err := wireguard.PruneInvalidPeers(ctx, s.Client); err != nil {
if err := wireguard.PruneInvalidPeers(ctx, s.GetClient(ctx)); err != nil {
s.printf("failed pruning invalid peers: %v", err)
}

Expand All @@ -393,74 +390,35 @@ func (s *server) clean(ctx context.Context) {
}
}

// updateMacaroons prunes expired macaroons and attempts to fetch discharge
// tokens as necessary.
func (s *server) updateMacaroonsInMemory(ctx context.Context) {
toks := config.Tokens(ctx)

ticker := time.NewTicker(time.Minute)
defer ticker.Stop()

var lastErr error

for {
if _, err := toks.Update(ctx, tokens.WithDebugger(s)); err != nil && err != lastErr {
s.print("failed upgrading authentication tokens:", err)
lastErr = err
}
// GetClient returns an API client that uses the server's tokens. Sessions may
// have their own tokens, so should use session.getClient instead.
func (s *server) GetClient(ctx context.Context) *fly.Client {
s.mu.Lock()
defer s.mu.Unlock()

select {
case <-ticker.C:
case <-ctx.Done():
return
}
}
return flyutil.NewClientFromOptions(ctx, fly.ClientOptions{Tokens: s.tokens})
}

// updateMacaroons prunes expired tokens and fetches discharge tokens as
// necessary. Those updates are written back to the config file.
func (s *server) updateMacaroonsInFile(ctx context.Context, path string) {
configToks := config.Tokens(ctx)

ticker := time.NewTicker(time.Minute)
defer ticker.Stop()

var lastErr error

for {
select {
case <-ctx.Done():
return
case <-ticker.C:
}
// UpdateTokensFromClient replaces the server's tokens with those from the
// client if the new ones seem better. Specifically, if the agent was started
// with `FLY_API_TOKEN`, but a later client is using tokens form a config file.
func (s *server) UpdateTokensFromClient(t *tokens.Tokens) {
s.mu.Lock()
defer s.mu.Unlock()

// the tokens in the config are continually updated as the config file
// changes. We do our updates on a copy of the tokens so we can still
// tell if the tokens in the config changed out from under us.
configToksBefore := configToks.All()
localToks := tokens.Parse(configToksBefore)
if s.tokens.FromFile() != "" || t.FromFile() == "" {
return
}

updated, err := localToks.Update(ctx, tokens.WithDebugger(s))
if err != nil && err != lastErr {
s.print("failed upgrading authentication tokens:", err)
lastErr = err
s.print("received new tokens from client")

// Don't continue loop here! It might only be partial failure
}
s.cancelTokenMonitoring()

// the consequences of a race here (agent and foreground command both
// fetching updates simultaneously) are low, so don't bother with a lock
// file.
if updated && configToks.All() == configToksBefore {
if err := config.SetAccessToken(path, localToks.All()); err != nil {
s.print("Failed to persist authentication token:", err)
s.updateMacaroonsInMemory(ctx)
return
}
monitorCtx, cancelMonitor := context.WithCancel(s.runCtx)
config.MonitorTokens(monitorCtx, t, nil)

s.print("Authentication tokens upgraded")
}
}
s.tokens = t
s.cancelTokenMonitoring = cancelMonitor
}

func (s *server) print(v ...interface{}) {
Expand All @@ -470,7 +428,3 @@ func (s *server) print(v ...interface{}) {
func (s *server) printf(format string, v ...interface{}) {
s.Logger.Printf(format, v...)
}

func (s *server) Debug(v ...any) {
s.Logger.Print(v...)
}
Loading

0 comments on commit 22fb63e

Please sign in to comment.