diff --git a/agent/client.go b/agent/client.go index 9d0368d07c..7ec491c95e 100644 --- a/agent/client.go +++ b/agent/client.go @@ -22,6 +22,7 @@ import ( "github.com/superfly/flyctl/agent/internal/proto" "github.com/superfly/flyctl/gql" "github.com/superfly/flyctl/internal/buildinfo" + "github.com/superfly/flyctl/internal/config" "github.com/superfly/flyctl/internal/flag" "github.com/superfly/flyctl/internal/logger" "github.com/superfly/flyctl/internal/sentry" @@ -127,14 +128,53 @@ const ( ) type Client struct { - network string - address string - dialer net.Dialer + network string + address string + dialer net.Dialer + agentRefusedTokens bool } var errDone = errors.New("done") -func (c *Client) do(parent context.Context, fn func(net.Conn) error) (err error) { +func (c *Client) do(ctx context.Context, fn func(net.Conn) error) (err error) { + if c.agentRefusedTokens { + return c.doNoTokens(ctx, fn) + } + + toks := config.Tokens(ctx) + if toks.Empty() { + return c.doNoTokens(ctx, fn) + } + + var tokArgs []string + if file := toks.FromFile(); file != "" { + tokArgs = append(tokArgs, "cfg", file) + } else { + tokArgs = append(tokArgs, "str", toks.All()) + } + + return c.doNoTokens(ctx, func(conn net.Conn) error { + if err := proto.Write(conn, "set-token", tokArgs...); err != nil { + return err + } + + data, err := proto.Read(conn) + + switch { + case err == nil && string(data) == "ok": + return fn(conn) + case err != nil: + return err + case isError(data): + c.agentRefusedTokens = true + return c.do(ctx, fn) + default: + return err + } + }) +} + +func (c *Client) doNoTokens(parent context.Context, fn func(net.Conn) error) (err error) { var conn net.Conn if conn, err = c.dialContext(parent); err != nil { return err diff --git a/agent/server/server.go b/agent/server/server.go index baaaeeedfd..4e60c0e790 100644 --- a/agent/server/server.go +++ b/agent/server/server.go @@ -18,6 +18,7 @@ import ( "github.com/superfly/flyctl/agent" "github.com/superfly/flyctl/internal/config" "github.com/superfly/flyctl/internal/env" + "github.com/superfly/flyctl/internal/flyutil" "github.com/superfly/flyctl/internal/sentry" "github.com/superfly/flyctl/internal/wireguard" "github.com/superfly/flyctl/wg" @@ -27,7 +28,6 @@ import ( type Options struct { Socket string Logger *log.Logger - Client *fly.Client Background bool ConfigFile string ConfigWebsockets bool @@ -51,11 +51,19 @@ func Run(ctx context.Context, opt Options) (err error) { return } + toks := config.Tokens(ctx) + + monitorCtx, cancelMonitor := context.WithCancel(ctx) + config.MonitorTokens(monitorCtx, toks, nil) + err = (&server{ - Options: opt, - listener: l, - currentChange: latestChangeAt, - tunnels: make(map[tunnelKey]*wg.Tunnel), + Options: opt, + listener: l, + runCtx: ctx, + currentChange: latestChangeAt, + tunnels: make(map[tunnelKey]*wg.Tunnel), + tokens: toks, + cancelTokenMonitoring: cancelMonitor, }).serve(ctx, l) return @@ -112,9 +120,12 @@ type server struct { listener net.Listener - mu sync.Mutex - currentChange time.Time - tunnels map[tunnelKey]*wg.Tunnel + runCtx context.Context + mu sync.Mutex + currentChange time.Time + tunnels map[tunnelKey]*wg.Tunnel + tokens *tokens.Tokens + cancelTokenMonitoring func() } type terminateError struct{ error } @@ -142,20 +153,6 @@ func (s *server) serve(parent context.Context, l net.Listener) (err error) { return nil }) - if toks := config.Tokens(ctx); len(toks.MacaroonTokens) != 0 { - eg.Go(func() error { - if f := toks.FromConfigFile; f == "" { - s.print("monitoring for token expiration") - 
s.updateMacaroonsInMemory(ctx)
-		} else {
-			s.print("monitoring for token changes and expiration")
-			s.updateMacaroonsInFile(ctx, f)
-		}
-
-		return nil
-	})
-}
-
 	eg.Go(func() (err error) {
 		s.printf("OK %d", os.Getpid())
 		defer s.print("QUIT")
@@ -227,7 +224,7 @@ func (s *server) checkForConfigChange() (err error) {
 	return
 }
 
-func (s *server) buildTunnel(ctx context.Context, org *fly.Organization, recycle bool, network string) (tunnel *wg.Tunnel, err error) {
+func (s *server) buildTunnel(ctx context.Context, org *fly.Organization, recycle bool, network string, client *fly.Client) (tunnel *wg.Tunnel, err error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
@@ -240,7 +237,7 @@
 	}
 
 	var state *wg.WireGuardState
-	if state, err = wireguard.StateForOrg(ctx, s.Client, org, os.Getenv("FLY_AGENT_WG_REGION"), "", recycle, network); err != nil {
+	if state, err = wireguard.StateForOrg(ctx, client, org, os.Getenv("FLY_AGENT_WG_REGION"), "", recycle, network); err != nil {
 		return
 	}
@@ -381,7 +378,7 @@ func (s *server) clean(ctx context.Context) {
 			break
 		}
 
-		if err := wireguard.PruneInvalidPeers(ctx, s.Client); err != nil {
+		if err := wireguard.PruneInvalidPeers(ctx, s.GetClient(ctx)); err != nil {
 			s.printf("failed pruning invalid peers: %v", err)
 		}
@@ -393,74 +390,35 @@
 	}
 }
 
-// updateMacaroons prunes expired macaroons and attempts to fetch discharge
-// tokens as necessary.
-func (s *server) updateMacaroonsInMemory(ctx context.Context) {
-	toks := config.Tokens(ctx)
-
-	ticker := time.NewTicker(time.Minute)
-	defer ticker.Stop()
-
-	var lastErr error
-
-	for {
-		if _, err := toks.Update(ctx, tokens.WithDebugger(s)); err != nil && err != lastErr {
-			s.print("failed upgrading authentication tokens:", err)
-			lastErr = err
-		}
+// GetClient returns an API client that uses the server's tokens. Sessions may
+// have their own tokens, so they should use session.getClient instead.
+func (s *server) GetClient(ctx context.Context) *fly.Client {
+	s.mu.Lock()
+	defer s.mu.Unlock()
 
-		select {
-		case <-ticker.C:
-		case <-ctx.Done():
-			return
-		}
-	}
+	return flyutil.NewClientFromOptions(ctx, fly.ClientOptions{Tokens: s.tokens})
 }
 
-// updateMacaroons prunes expired tokens and fetches discharge tokens as
-// necessary. Those updates are written back to the config file.
-func (s *server) updateMacaroonsInFile(ctx context.Context, path string) {
-	configToks := config.Tokens(ctx)
-
-	ticker := time.NewTicker(time.Minute)
-	defer ticker.Stop()
-
-	var lastErr error
-
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case <-ticker.C:
-		}
+// UpdateTokensFromClient replaces the server's tokens with those from the
+// client if the new ones seem better. Specifically, this happens when the
+// agent was started with `FLY_API_TOKEN` but a later client is using tokens
+// from a config file.
+func (s *server) UpdateTokensFromClient(t *tokens.Tokens) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
 
-		// the tokens in the config are continually updated as the config file
-		// changes. We do our updates on a copy of the tokens so we can still
-		// tell if the tokens in the config changed out from under us.
- configToksBefore := configToks.All() - localToks := tokens.Parse(configToksBefore) + if s.tokens.FromFile() != "" || t.FromFile() == "" { + return + } - updated, err := localToks.Update(ctx, tokens.WithDebugger(s)) - if err != nil && err != lastErr { - s.print("failed upgrading authentication tokens:", err) - lastErr = err + s.print("received new tokens from client") - // Don't continue loop here! It might only be partial failure - } + s.cancelTokenMonitoring() - // the consequences of a race here (agent and foreground command both - // fetching updates simultaneously) are low, so don't bother with a lock - // file. - if updated && configToks.All() == configToksBefore { - if err := config.SetAccessToken(path, localToks.All()); err != nil { - s.print("Failed to persist authentication token:", err) - s.updateMacaroonsInMemory(ctx) - return - } + monitorCtx, cancelMonitor := context.WithCancel(s.runCtx) + config.MonitorTokens(monitorCtx, t, nil) - s.print("Authentication tokens upgraded") - } - } + s.tokens = t + s.cancelTokenMonitoring = cancelMonitor } func (s *server) print(v ...interface{}) { @@ -470,7 +428,3 @@ func (s *server) print(v ...interface{}) { func (s *server) printf(format string, v ...interface{}) { s.Logger.Printf(format, v...) } - -func (s *server) Debug(v ...any) { - s.Logger.Print(v...) -} diff --git a/agent/server/session.go b/agent/server/session.go index 354f9b0af3..f5fecdd1fd 100644 --- a/agent/server/session.go +++ b/agent/server/session.go @@ -20,11 +20,14 @@ import ( "golang.org/x/sync/errgroup" fly "github.com/superfly/fly-go" + "github.com/superfly/fly-go/tokens" "github.com/superfly/flyctl/agent" "github.com/superfly/flyctl/agent/internal/proto" "github.com/superfly/flyctl/wg" "github.com/superfly/flyctl/internal/buildinfo" + "github.com/superfly/flyctl/internal/config" + "github.com/superfly/flyctl/internal/flyutil" ) type id uint64 @@ -38,6 +41,7 @@ type session struct { conn net.Conn logger *log.Logger id id + tokens *tokens.Tokens } var errUnsupportedCommand = errors.New("unsupported command") @@ -81,6 +85,10 @@ func runSession(ctx context.Context, srv *server, conn net.Conn, id id) { return } + s.runCommand(ctx) +} + +func (s *session) runCommand(ctx context.Context) { buf, err := proto.Read(s.conn) if len(buf) > 0 { s.logger.Printf("<- (% 5d) %q", len(buf), redact(buf)) @@ -95,30 +103,37 @@ func runSession(ctx context.Context, srv *server, conn net.Conn, id id) { } args := strings.Split(string(buf), " ") - - fn := handlers[args[0]] - if fn == nil { + var handler func(*session, context.Context, ...string) + + switch args[0] { + case "kill": + handler = (*session).kill + case "ping": + handler = (*session).ping + case "establish": + handler = (*session).establish + case "reestablish": + handler = (*session).reestablish + case "connect": + handler = (*session).connect + case "probe": + handler = (*session).probe + case "instances": + handler = (*session).instances + case "resolve": + handler = (*session).resolve + case "lookupTxt": + handler = (*session).lookupTxt + case "ping6": + handler = (*session).ping6 + case "set-token": + handler = (*session).setToken + default: s.error(errUnsupportedCommand) - return } - fn(s, ctx, args[1:]...) 
-} - -type handlerFunc func(*session, context.Context, ...string) - -var handlers = map[string]handlerFunc{ - "kill": (*session).kill, - "ping": (*session).ping, - "establish": (*session).establish, - "reestablish": (*session).reestablish, - "connect": (*session).connect, - "probe": (*session).probe, - "instances": (*session).instances, - "resolve": (*session).resolve, - "lookupTxt": (*session).lookupTxt, - "ping6": (*session).ping6, + handler(s, ctx, args[1:]...) } var errMalformedKill = errors.New("malformed kill command") @@ -162,7 +177,7 @@ func (s *session) doEstablish(ctx context.Context, recycle bool, args ...string) return } - tunnel, err := s.srv.buildTunnel(ctx, org, recycle, args[1]) + tunnel, err := s.srv.buildTunnel(ctx, org, recycle, args[1], s.getClient(ctx)) if err != nil { s.error(err) @@ -186,7 +201,7 @@ func (s *session) reestablish(ctx context.Context, args ...string) { var errNoSuchOrg = errors.New("no such organization") func (s *session) fetchOrg(ctx context.Context, slug string) (*fly.Organization, error) { - orgs, err := s.srv.Client.GetOrganizations(ctx) + orgs, err := s.getClient(ctx).GetOrganizations(ctx) if err != nil { return nil, err } @@ -534,6 +549,44 @@ func (s *session) ping6(ctx context.Context, args ...string) { } } +var errMalformedSetToken = errors.New("malformed set-token command") + +// setToken instructs the agent which tokens to use for API calls. +func (s *session) setToken(ctx context.Context, args ...string) { + if !s.exactArgs(2, args, errMalformedSetToken) { + return + } + + switch args[0] { + case "cfg": + tokStr, err := config.ReadAccessToken(args[1]) + if err != nil { + s.error(err) + return + } + + s.tokens = tokens.ParseFromFile(tokStr, args[1]) + case "str": + s.tokens = tokens.Parse(args[1]) + } + + go s.srv.UpdateTokensFromClient(s.tokens) + + s.ok() + + s.runCommand(ctx) +} + +// getClient returns an API client that uses any API tokens sent by the client. +// If none have been sent, it falls back to using the server's tokens. 
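+// Clients send their tokens with the set-token command before issuing other
+// commands, so a session's tokens may differ from the server's.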
+func (s *session) getClient(ctx context.Context) *fly.Client { + if s.tokens == nil { + return s.srv.GetClient(ctx) + } + + return flyutil.NewClientFromOptions(ctx, fly.ClientOptions{Tokens: s.tokens}) +} + func (s *session) error(err error) bool { return s.reply("err", err.Error()) } @@ -607,8 +660,14 @@ func isClosed(err error) bool { return errors.Is(err, net.ErrClosed) } -var redactRx = regexp.MustCompile(`(PrivateKey|private)":".*?"`) +var ( + redactPrivateKeyRx = regexp.MustCompile(`(PrivateKey|private)":".*?"`) + redactTokenRx = regexp.MustCompile(`(fo1_|fm1[ar]_|fm2_)[a-zA-Z0-9/+_-]+=*`) +) func redact(buf []byte) []byte { - return redactRx.ReplaceAll(buf, []byte(`PrivateKey":"[redacted]"`)) + buf = redactPrivateKeyRx.ReplaceAll(buf, []byte(`PrivateKey":"[redacted]"`)) + buf = redactTokenRx.ReplaceAll(buf, []byte(`$1[redacted]`)) + return buf + } diff --git a/agent/start.go b/agent/start.go index d75d2e8c9d..3609a56adc 100644 --- a/agent/start.go +++ b/agent/start.go @@ -45,8 +45,8 @@ func StartDaemon(ctx context.Context) (*Client, error) { env = append(env, "FLY_NO_UPDATE_CHECK=1") // if our tokens came from the config file, let agent get them there too - if toks := config.Tokens(ctx); toks.FromConfigFile == "" { - env = append(env, fmt.Sprintf("FLY_API_TOKEN=%s", config.Tokens(ctx).GraphQL())) + if toks := config.Tokens(ctx); toks.FromFile() == "" { + env = append(env, fmt.Sprintf("FLY_API_TOKEN=%s", config.Tokens(ctx).All())) } cmd.Env = env diff --git a/go.mod b/go.mod index 56323d2c46..2893e399b1 100644 --- a/go.mod +++ b/go.mod @@ -24,7 +24,6 @@ require ( github.com/docker/go-units v0.5.0 github.com/dustin/go-humanize v1.0.1 github.com/ejcx/sshcert v1.1.0 - github.com/fsnotify/fsnotify v1.7.0 github.com/getsentry/sentry-go v0.27.0 github.com/go-logr/logr v1.4.1 github.com/gofrs/flock v0.8.1 @@ -63,7 +62,7 @@ require ( github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.18.2 github.com/stretchr/testify v1.9.0 - github.com/superfly/fly-go v0.1.5-0.20240413215828-21ef94bc21dc + github.com/superfly/fly-go v0.1.5-0.20240415152134-3bdfded3a15c github.com/superfly/graphql v0.2.4 github.com/superfly/lfsc-go v0.1.1 github.com/superfly/macaroon v0.2.12 @@ -157,6 +156,7 @@ require ( github.com/emirpasic/gods v1.18.1 // indirect github.com/fatih/color v1.15.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gdamore/encoding v1.0.0 // indirect github.com/gdamore/tcell/v2 v2.7.0 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect diff --git a/go.sum b/go.sum index 33d92e2cf9..8a85051a34 100644 --- a/go.sum +++ b/go.sum @@ -614,8 +614,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= -github.com/superfly/fly-go v0.1.5-0.20240413215828-21ef94bc21dc h1:Tk254Ggrt4kOZfaq6c3BWUwVuABFYaKXudB61aApniw= -github.com/superfly/fly-go v0.1.5-0.20240413215828-21ef94bc21dc/go.mod h1:MUKMuc5Tg+qmDmAi5jBMKcuitw4SNhrZ60fS4jqIZpQ= +github.com/superfly/fly-go v0.1.5-0.20240415152134-3bdfded3a15c h1:rC5JrtJE5nDzFNhBBl7ru+/nq0gClgwp8HazP2pHk2c= +github.com/superfly/fly-go v0.1.5-0.20240415152134-3bdfded3a15c/go.mod h1:MUKMuc5Tg+qmDmAi5jBMKcuitw4SNhrZ60fS4jqIZpQ= github.com/superfly/graphql v0.2.4 
h1:Av8hSk4x8WvKJ6MTnEwrLknSVSGPc7DWpgT3z/kt3PU= github.com/superfly/graphql v0.2.4/go.mod h1:CVfDl31srm8HnJ9udwLu6hFNUW/P6GUM2dKcG1YQ8jc= github.com/superfly/lfsc-go v0.1.1 h1:dGjLgt81D09cG+aR9lJZIdmonjZSR5zYCi7s54+ZU2Q= diff --git a/internal/command/agent/run.go b/internal/command/agent/run.go index 4d3c39f1de..629683e3a6 100644 --- a/internal/command/agent/run.go +++ b/internal/command/agent/run.go @@ -16,6 +16,7 @@ import ( "github.com/superfly/flyctl/flyctl" "github.com/superfly/flyctl/internal/command" + "github.com/superfly/flyctl/internal/config" "github.com/superfly/flyctl/internal/filemu" "github.com/superfly/flyctl/internal/flag" "github.com/superfly/flyctl/internal/state" @@ -27,9 +28,10 @@ func newRun() (cmd *cobra.Command) { long = short + "\n" ) - cmd = command.New("run", short, long, run, - command.RequireSession, - ) + // Don't use RequireSession preparer. It does its own token monitoring and + // will try to run token discharge flows that would involve opening URLs in + // the user's browser. We don't want to do that in a background agent. + cmd = command.New("run", short, long, run) cmd.Args = cobra.MaximumNArgs(1) cmd.Aliases = []string{"daemon-start"} @@ -48,6 +50,11 @@ func run(ctx context.Context) error { } defer closeLogger() + if config.Tokens(ctx).GraphQL() == "" { + logger.Println(fly.ErrNoAuthToken) + return fly.ErrNoAuthToken + } + unlock, err := lock(ctx, logger) if err != nil { return err @@ -57,7 +64,6 @@ func run(ctx context.Context) error { opt := server.Options{ Socket: socketPath(ctx), Logger: logger, - Client: fly.ClientFromContext(ctx), Background: logPath != "", ConfigFile: state.ConfigFile(ctx), ConfigWebsockets: viper.GetBool(flyctl.ConfigWireGuardWebsockets), diff --git a/internal/command/command.go b/internal/command/command.go index cad977e1e2..9351febe85 100644 --- a/internal/command/command.go +++ b/internal/command/command.go @@ -15,7 +15,6 @@ import ( "github.com/skratchdot/open-golang/open" "github.com/spf13/cobra" fly "github.com/superfly/fly-go" - "github.com/superfly/fly-go/tokens" "github.com/superfly/flyctl/internal/command/auth/webauth" "github.com/superfly/flyctl/internal/prompt" "github.com/superfly/flyctl/iostreams" @@ -551,53 +550,7 @@ func RequireSession(ctx context.Context) (context.Context, error) { } } - return updateMacaroons(ctx) -} - -// updateMacaroons prune any invalid/expired macaroons and fetch needed third -// party discharges -func updateMacaroons(ctx context.Context) (context.Context, error) { - var ( - log = logger.FromContext(ctx) - cfg = config.FromContext(ctx) - toks = cfg.Tokens - ) - - updated, err := toks.Update(ctx, - tokens.WithUserURLCallback(tryOpenUserURL), - tokens.WithDebugger(log), - ) - if err != nil { - log.Warn("Failed to upgrade authentication token. 
Command may fail.") - log.Debug(err) - } - - if toks.FromConfigFile == "" { - return ctx, nil - } - - if updated { - if err := config.SetAccessToken(toks.FromConfigFile, toks.All()); err != nil { - log.Warn("Failed to persist authentication token.") - log.Debug(err) - } - } - - sub, err := cfg.Watch(ctx) - if err != nil { - log.Warn("Failed to watch config file for changes.") - log.Debug(err) - return ctx, nil - } - - go func() { - for newCfg := range sub { - if cfg.Tokens.All() != newCfg.Tokens.All() { - log.Debug("Authentication tokens updated from config file.") - cfg.Tokens.Replace(newCfg.Tokens) - } - } - }() + config.MonitorTokens(ctx, config.Tokens(ctx), tryOpenUserURL) return ctx, nil } diff --git a/internal/command/tokens/attenuate.go b/internal/command/tokens/attenuate.go index c6711d5f87..61804091a6 100644 --- a/internal/command/tokens/attenuate.go +++ b/internal/command/tokens/attenuate.go @@ -85,7 +85,7 @@ func getPermissionAndDischargeTokens(ctx context.Context) ([]*macaroon.Macaroon, } func getTokens(ctx context.Context) ([][]byte, error) { - token := config.Tokens(ctx).Macaroons() + token := config.Tokens(ctx).MacaroonsOnly().All() if token == "" { return nil, errors.New("pass a macaroon token (e.g. from `fly tokens deploy`) as the -t argument or in FLY_API_TOKEN") diff --git a/internal/config/config.go b/internal/config/config.go index e2e9567c5e..c64a4dca15 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,16 +5,13 @@ import ( "errors" "io/fs" "sync" - "time" - "github.com/fsnotify/fsnotify" "github.com/spf13/pflag" "github.com/superfly/fly-go/tokens" "github.com/superfly/flyctl/internal/env" "github.com/superfly/flyctl/internal/flag/flagctx" "github.com/superfly/flyctl/internal/flag/flagnames" - "github.com/superfly/flyctl/internal/task" ) const ( @@ -54,12 +51,7 @@ const ( // // Instances of Config are safe for concurrent use. type Config struct { - mu sync.RWMutex - path string - - watchOnce sync.Once - watchErr error - subs map[chan *Config]struct{} + mu sync.RWMutex // APIBaseURL denotes the base URL of the API. 
APIBaseURL string @@ -161,8 +153,6 @@ func (cfg *Config) applyFile(path string) (err error) { cfg.mu.Lock() defer cfg.mu.Unlock() - cfg.path = path - var w struct { AccessToken string `yaml:"access_token"` MetricsToken string `yaml:"metrics_token"` @@ -173,9 +163,7 @@ func (cfg *Config) applyFile(path string) (err error) { w.AutoUpdate = true if err = unmarshal(path, &w); err == nil { - cfg.Tokens = tokens.Parse(w.AccessToken) - cfg.Tokens.FromConfigFile = path - + cfg.Tokens = tokens.ParseFromFile(w.AccessToken, path) cfg.MetricsToken = w.MetricsToken cfg.SendMetrics = w.SendMetrics cfg.AutoUpdate = w.AutoUpdate @@ -214,141 +202,6 @@ func (cfg *Config) MetricsBaseURLIsProduction() bool { return cfg.MetricsBaseURL == defaultMetricsBaseURL } -func (cfg *Config) Watch(ctx context.Context) (chan *Config, error) { - cfg.watchOnce.Do(func() { - watch, err := fsnotify.NewWatcher() - if err != nil { - cfg.watchErr = err - return - } - - if err := watch.Add(cfg.path); err != nil { - cfg.watchErr = err - return - } - - cfg.subs = make(map[chan *Config]struct{}) - - task.FromContext(ctx).Run(func(ctx context.Context) { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - cleanupDone := make(chan struct{}) - defer func() { <-cleanupDone }() - - go func() { - defer close(cleanupDone) - - <-ctx.Done() - - cfg.mu.Lock() - defer cfg.mu.Unlock() - - cfg.watchErr = errors.Join(cfg.watchErr, ctx.Err(), watch.Close()) - - for sub := range cfg.subs { - close(sub) - } - cfg.subs = nil - }() - - var ( - notifyCtx context.Context - cancelNotify context.CancelFunc = func() {} - cancelLastNotify *context.CancelFunc = &cancelNotify - ) - defer func() { (*cancelLastNotify)() }() - - for { - select { - case e, open := <-watch.Events: - if !open { - return - } - - if !e.Has(fsnotify.Write) { - continue - } - - // Debounce change notifications: notifySubs sleeps for 50ms - // before notifying subs. If we get another change before - // that, we preempt the previous notification attempt. This - // is necessary because we receive multiple notifications - // for a single config change on windows and the first event - // fires before the change is available to be read. 
- (*cancelLastNotify)() - notifyCtx, cancelNotify = context.WithCancel(ctx) - cancelLastNotify = &cancelNotify - - go cfg.notifySubs(notifyCtx) - case err := <-watch.Errors: - cfg.mu.Lock() - defer cfg.mu.Unlock() - - cfg.watchErr = errors.Join(cfg.watchErr, err) - - return - case <-ctx.Done(): - return - } - } - }) - }) - - cfg.mu.Lock() - defer cfg.mu.Unlock() - - if cfg.watchErr != nil { - return nil, cfg.watchErr - } - - sub := make(chan *Config) - cfg.subs[sub] = struct{}{} - - return sub, nil -} - -func (cfg *Config) Unwatch(sub chan *Config) { - cfg.mu.Lock() - defer cfg.mu.Unlock() - - if cfg.subs != nil { - delete(cfg.subs, sub) - close(sub) - } -} - -func (cfg *Config) notifySubs(ctx context.Context) { - // sleep for 50ms to facilitate debouncing (described above) - select { - case <-ctx.Done(): - return - case <-time.After(50 * time.Millisecond): - } - - newCfg, err := Load(ctx, cfg.path) - if err != nil { - return - } - - cfg.mu.RLock() - defer cfg.mu.RUnlock() - - // just in case we have a slow subscriber - timer := time.NewTimer(100 * time.Millisecond) - defer timer.Stop() - - for sub := range cfg.subs { - select { - case sub <- newCfg: - case <-timer.C: - return - case <-ctx.Done(): - return - } - } -} - func applyStringFlags(fs *pflag.FlagSet, flags map[string]*string) { for name, dst := range flags { if !fs.Changed(name) { diff --git a/internal/config/config_test.go b/internal/config/config_test.go deleted file mode 100644 index 1eb593794c..0000000000 --- a/internal/config/config_test.go +++ /dev/null @@ -1,147 +0,0 @@ -package config - -import ( - "context" - "errors" - "io" - "os" - "path" - "sync" - "testing" - "time" - - "github.com/spf13/pflag" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/superfly/flyctl/flyctl" - "github.com/superfly/flyctl/internal/flag/flagctx" - "github.com/superfly/flyctl/internal/logger" - "github.com/superfly/flyctl/internal/task" -) - -func TestConfigWatch(t *testing.T) { - cfgDirWas, cfgDirWasSet := os.LookupEnv("FLY_CONFIG_DIR") - os.Setenv("FLY_CONFIG_DIR", t.TempDir()) - flyctl.InitConfig() - t.Cleanup(func() { - if cfgDirWasSet { - os.Setenv("FLY_CONFIG_DIR", cfgDirWas) - } else { - os.Unsetenv("FLY_CONFIG_DIR") - } - }) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ctx = logger.NewContext(ctx, logger.New(io.Discard, logger.Error, false)) - ctx = flagctx.NewContext(ctx, new(pflag.FlagSet)) - - tm := task.New() - tm.Start(ctx) - ctx = task.WithContext(ctx, tm) - - path := path.Join(t.TempDir(), "config.yml") - - require.NoError(t, os.WriteFile(path, []byte(`access_token: fo1_foo`), 0644)) - cfg, err := Load(ctx, path) - require.NoError(t, err) - require.Equal(t, "fo1_foo", cfg.Tokens.All()) - - c1, err := cfg.Watch(ctx) - require.NoError(t, err) - - c2, err := cfg.Watch(ctx) - require.NoError(t, err) - - cfgs, errs := getConfigChanges(c1, c2) - require.Equal(t, 2, len(errs)) - require.Equal(t, 0, len(cfgs)) - - require.NoError(t, os.WriteFile(path, []byte(`access_token: fo1_bar`), 0644)) - - cfgs, errs = getConfigChanges(c1, c2) - require.Equal(t, 0, len(errs), errs) - require.Equal(t, 2, len(cfgs)) - require.Equal(t, cfgs[0], cfgs[1]) - require.Equal(t, "fo1_bar", cfgs[0].Tokens.All()) - - // debouncing - require.NoError(t, os.WriteFile(path, []byte(`access_token: fo1_aaa`), 0644)) - require.NoError(t, os.WriteFile(path, []byte(`access_token: fo1_bbb`), 0644)) - - cfgs, errs = getConfigChanges(c1, c2) - require.Equal(t, 0, len(errs)) - require.Equal(t, 2, 
len(cfgs))
-	require.Equal(t, cfgs[0], cfgs[1])
-	require.Equal(t, "fo1_bbb", cfgs[0].Tokens.All())
-
-	cfgs, errs = getConfigChanges(c1, c2)
-	require.Equal(t, 2, len(errs))
-	require.Equal(t, 0, len(cfgs))
-
-	cfg.Unwatch(c1)
-
-	require.NoError(t, os.WriteFile(path, []byte(`access_token: fo1_baz`), 0644))
-
-	cfgs, errs = getConfigChanges(c2)
-	require.Equal(t, 0, len(errs))
-	require.Equal(t, 1, len(cfgs))
-	require.Equal(t, "fo1_baz", cfgs[0].Tokens.All())
-
-	shutdown := make(chan struct{})
-	go func() {
-		defer close(shutdown)
-		tm.Shutdown()
-	}()
-	select {
-	case <-shutdown:
-	case <-time.After(50 * time.Millisecond):
-		t.Fatal("slow shutdown")
-	}
-
-	_, open := <-c1
-	require.False(t, open)
-	_, open = <-c2
-	require.False(t, open)
-
-	_, err = cfg.Watch(ctx)
-	assert.Error(t, err)
-	require.EqualError(t, err, context.Canceled.Error())
-}
-
-func getConfigChanges(chans ...chan *Config) ([]*Config, []error) {
-	var (
-		cfgs []*Config
-		errs []error
-		m    sync.Mutex
-		wg   sync.WaitGroup
-	)
-
-	for _, ch := range chans {
-		ch := ch
-
-		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			defer m.Unlock()
-
-			select {
-			case cfg, open := <-ch:
-				m.Lock()
-				if open {
-					cfgs = append(cfgs, cfg)
-				} else {
-					errs = append(errs, errors.New("closed"))
-				}
-			case <-time.After(100 * time.Millisecond):
-				m.Lock()
-				errs = append(errs, errors.New("timeout"))
-			}
-		}()
-	}
-
-	wg.Wait()
-
-	return cfgs, errs
-}
diff --git a/internal/config/tokens.go b/internal/config/tokens.go
new file mode 100644
index 0000000000..1aa4a05352
--- /dev/null
+++ b/internal/config/tokens.go
@@ -0,0 +1,353 @@
+package config
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"slices"
+	"strconv"
+	"sync"
+	"time"
+
+	"github.com/superfly/fly-go"
+	"github.com/superfly/fly-go/tokens"
+	"github.com/superfly/flyctl/gql"
+	"github.com/superfly/flyctl/internal/flyutil"
+	"github.com/superfly/flyctl/internal/logger"
+	"github.com/superfly/flyctl/internal/task"
+	"github.com/superfly/macaroon"
+	"github.com/superfly/macaroon/flyio"
+	"golang.org/x/exp/maps"
+)
+
+// UserURLCallback is a function that opens a URL in the user's browser. This is
+// used for token discharge flows that require user interaction.
+type UserURLCallback func(ctx context.Context, url string) error
+
+// MonitorTokens does some housekeeping on the provided tokens. Then, in a
+// goroutine, it continues to keep the tokens updated and fresh. The call to
+// MonitorTokens will return as soon as the tokens are ready for use and the
+// background job will run until the context is cancelled. Token updates include:
+// - Keeping the tokens synced with the config file.
+// - Refreshing any expired discharge tokens.
+// - Pruning expired or invalid tokens.
+// - Fetching macaroons for any organizations the user has been added to.
+// - Pruning tokens for organizations the user is no longer a member of.
+func MonitorTokens(monitorCtx context.Context, t *tokens.Tokens, uucb UserURLCallback) {
+	log := logger.FromContext(monitorCtx)
+	file := t.FromFile()
+
+	updated1, err := fetchOrgTokens(monitorCtx, t)
+	if err != nil {
+		log.Debugf("failed to fetch missing org tokens: %s", err)
+	}
+
+	updated2, err := refreshDischargeTokens(monitorCtx, t, uucb)
+	if err != nil {
+		log.Debugf("failed to update discharge tokens: %s", err)
+	}
+
+	if file != "" && (updated1 || updated2) {
+		if err := SetAccessToken(file, t.All()); err != nil {
+			log.Debugf("failed to persist updated tokens: %s", err)
+		}
+	}
+
+	task.FromContext(monitorCtx).Run(func(taskCtx context.Context) {
+		taskCtx, cancelTask := context.WithCancel(taskCtx)
+
+		var m sync.Mutex
+		var wg sync.WaitGroup
+
+		wg.Add(2)
+
+		if file != "" {
+			log.Debugf("monitoring tokens at %s", file)
+		} else {
+			log.Debug("monitoring tokens in memory")
+		}
+
+		go monitorConfigTokenChanges(taskCtx, &m, t, wg.Done)
+		go keepConfigTokensFresh(taskCtx, &m, t, uucb, wg.Done)
+
+		// shut down when the task manager is shutting down or when the
+		// ctx passed into MonitorTokens is cancelled.
+		select {
+		case <-taskCtx.Done():
+		case <-monitorCtx.Done():
+		}
+
+		log.Debug("done monitoring tokens")
+		cancelTask()
+		wg.Wait()
+	})
+}
+
+// monitorConfigTokenChanges watches for token changes in the config file. This can
+// happen if a foreground process updates the config file while the agent is
+// running.
+func monitorConfigTokenChanges(ctx context.Context, m *sync.Mutex, t *tokens.Tokens, done func()) error {
+	defer done()
+
+	file := t.FromFile()
+	if file == "" {
+		return nil
+	}
+
+	ticker := time.NewTicker(15 * time.Second)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-ticker.C:
+			currentStr, err := ReadAccessToken(file)
+			if err != nil {
+				return err
+			}
+
+			current := tokens.ParseFromFile(currentStr, file)
+
+			m.Lock()
+			t.Replace(current)
+			m.Unlock()
+		}
+	}
+}
+
+// keepConfigTokensFresh periodically updates our tokens and syncs those to the config
+// file.
+func keepConfigTokensFresh(ctx context.Context, m *sync.Mutex, t *tokens.Tokens, uucb UserURLCallback, done func()) error {
+	defer done()
+
+	ticker := time.NewTicker(time.Minute)
+	defer ticker.Stop()
+
+	logger := logger.FromContext(ctx)
+	file := t.FromFile()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-ticker.C:
+			localCopy := t.Copy()
+			beforeUpdate := t.Copy()
+
+			updated1, err := fetchOrgTokens(ctx, localCopy)
+			if err != nil {
+				logger.Debugf("failed to fetch missing org tokens: %s", err)
+				// don't continue. might have been partial success
+			}
+
+			updated2, err := refreshDischargeTokens(ctx, localCopy, uucb)
+			if err != nil {
+				logger.Debugf("failed to update discharge tokens: %s", err)
+				// don't continue. might have been partial success
+			}
+
+			if !updated2 && !updated1 {
+				continue
+			}
+
+			m.Lock()
+			// don't clobber config file if it changed out from under us. the
+			// consequences of a race here (agent and foreground command both
+			// fetching updates simultaneously) are low, so don't bother with an
+			// extra lock file.
+			if beforeUpdate.Equal(t) {
+				t.Replace(localCopy)
+
+				if file != "" {
+					if err := SetAccessToken(file, t.All()); err != nil {
+						logger.Debugf("failed to persist updated tokens: %s", err)
+
+						// don't try again if we fail to write once
+						file = ""
+					}
+				}
+			}
+			m.Unlock()
+		}
+	}
+}
+
+// refreshDischargeTokens attempts to refresh any expired discharge tokens. It
+// returns true if any tokens were updated, which might be the case even if
+// there was an error for other tokens.
+//
+// Some discharges may require user interaction in the form of opening a URL in
+// the user's browser. Pass a UserURLCallback if you want to support this.
+//
+// Don't call this when other goroutines might also be accessing t.
+func refreshDischargeTokens(ctx context.Context, t *tokens.Tokens, uucb UserURLCallback) (bool, error) {
+	updateOpts := []tokens.UpdateOption{tokens.WithDebugger(logger.FromContext(ctx))}
+
+	if uucb != nil {
+		updateOpts = append(updateOpts, tokens.WithUserURLCallback(uucb))
+	}
+
+	return t.Update(ctx, updateOpts...)
+}
+
+// fetchOrgTokens checks that we have macaroons for all orgs the user is a
+// member of. It returns true if any new tokens were added, which might be the
+// case even if there was an error.
+//
+// Don't call this when other goroutines might also be accessing t.
+func fetchOrgTokens(ctx context.Context, t *tokens.Tokens) (bool, error) {
+	return doFetchOrgTokens(ctx, t, defaultOrgFetcher, defaultTokenMinter)
+}
+
+func doFetchOrgTokens(ctx context.Context, t *tokens.Tokens, fetchOrgs orgFetcher, mintToken tokenMinter) (bool, error) {
+	macToks := t.GetMacaroonTokens()
+
+	if len(macToks) == 0 || len(t.GetUserTokens()) == 0 {
+		return false, nil
+	}
+
+	c := flyutil.NewClientFromOptions(ctx, fly.ClientOptions{Tokens: t.UserTokenOnly()})
+
+	graphIDByNumericID, err := fetchOrgs(ctx, c)
+	if err != nil {
+		return false, err
+	}
+
+	log := logger.FromContext(ctx)
+
+	tokOIDS := make(map[uint64]bool, len(macToks))
+	macToks = slices.DeleteFunc(macToks, func(tok string) bool {
+		toks, err := macaroon.Parse(tok)
+		if err != nil {
+			log.Debugf("pruning token: failed to parse macaroon: %v", err)
+			return true
+		}
+
+		permMacs, _, _, _, err := macaroon.FindPermissionAndDischargeTokens(toks, flyio.LocationPermission)
+		if err != nil {
+			log.Debugf("pruning token: failed to find permission token: %v", err)
+			return true
+		}
+
+		// discharge token?
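+		// (bundles without exactly one permission macaroon are kept, not pruned)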
+ if len(permMacs) != 1 { + return false + } + + oid, err := flyio.OrganizationScope(&permMacs[0].UnsafeCaveats) + if err != nil { + log.Debugf("pruning token: failed to calculate org scope: %v", err) + return true + } + + if _, hasOrg := graphIDByNumericID[oid]; !hasOrg { + log.Debug("pruning token: not in org") + return true + } + + tokOIDS[oid] = true + return false + }) + + // find missing orgs by deleting the ones we found + for oid := range tokOIDS { + delete(graphIDByNumericID, oid) + } + + var ( + wg sync.WaitGroup + wgErr error + wgLock sync.Mutex + ) + + addErr := func(err error) { + wgLock.Lock() + defer wgLock.Unlock() + wgErr = errors.Join(wgErr, err) + } + addMac := func(m string) { + wgLock.Lock() + defer wgLock.Unlock() + macToks = append(macToks, m) + } + for _, graphID := range maps.Values(graphIDByNumericID) { + graphID := graphID + + wg.Add(1) + go func() { + defer wg.Done() + + log.Debugf("fetching macaroons for org %s", graphID) + newToksStr, err := mintToken(ctx, c, graphID) + if err != nil { + addErr(fmt.Errorf("failed to get macaroons for org %s: %w", graphID, err)) + return + } + + newToks, err := macaroon.Parse(newToksStr) + if err != nil { + addErr(fmt.Errorf("bad macaroons for org %s: %w", graphID, err)) + return + } + + for _, newTok := range newToks { + m, err := macaroon.Decode(newTok) + if err != nil { + addErr(fmt.Errorf("bad macaroon for org %s: %w", graphID, err)) + return + } + + mStr, err := m.String() + if err != nil { + addErr(fmt.Errorf("failed encoding macaroon for org %s: %w", graphID, err)) + return + } + + addMac(mStr) + } + }() + } + wg.Wait() + + if slices.Equal(macToks, t.GetMacaroonTokens()) { + return false, wgErr + } + + t.ReplaceMacaroonTokens(macToks) + + return true, wgErr +} + +// orgFetcher allows us to stub out gql calls in tests +type orgFetcher func(context.Context, *fly.Client) (map[uint64]string, error) + +func defaultOrgFetcher(ctx context.Context, c *fly.Client) (map[uint64]string, error) { + orgs, err := c.GetOrganizations(ctx) + if err != nil { + return nil, err + } + + graphIDByNumericID := make(map[uint64]string, len(orgs)) + for _, org := range orgs { + if uintID, err := strconv.ParseUint(org.InternalNumericID, 10, 64); err == nil { + graphIDByNumericID[uintID] = org.ID + } + } + + return graphIDByNumericID, nil +} + +// tokenMinter allows us to stub out gql calls in tests +type tokenMinter func(context.Context, *fly.Client, string) (string, error) + +func defaultTokenMinter(ctx context.Context, c *fly.Client, id string) (string, error) { + resp, err := gql.CreateLimitedAccessToken(ctx, c.GenqClient, "flyctl", id, "deploy_organization", &gql.LimitedAccessTokenOptions{}, "10m") + if err != nil { + return "", err + } + + return resp.CreateLimitedAccessToken.GetLimitedAccessToken().TokenHeader, nil +} diff --git a/internal/config/tokens_test.go b/internal/config/tokens_test.go new file mode 100644 index 0000000000..2bff1e07c7 --- /dev/null +++ b/internal/config/tokens_test.go @@ -0,0 +1,198 @@ +package config + +import ( + "context" + "errors" + "os" + "slices" + "testing" + + "github.com/stretchr/testify/require" + "github.com/superfly/fly-go" + "github.com/superfly/fly-go/tokens" + "github.com/superfly/flyctl/internal/logger" + "github.com/superfly/macaroon" + "github.com/superfly/macaroon/flyio" + "github.com/superfly/macaroon/resset" +) + +func TestFetchOrgTokens(t *testing.T) { + ctx := logger.NewContext(context.Background(), logger.New(os.Stdout, logger.Debug, true)) + + // no tokens + created, err := 
doFetchOrgTokens(ctx, &tokens.Tokens{}, nil, nil) + require.False(t, created) + require.NoError(t, err) + + // no macaroons + created, err = doFetchOrgTokens(ctx, tokens.Parse("fo1_hi"), nil, nil) + require.False(t, created) + require.NoError(t, err) + + // no user token + created, err = doFetchOrgTokens(ctx, tokens.Parse("fm2_hi"), nil, nil) + require.False(t, created) + require.NoError(t, err) + + // basic case + toks := fakeTokens(t, "fo1_hi", 1) + fetchOrgs := fakeOrgFetcher(map[uint64]string{1: "org1", 2: "org2"}, nil) + mintToken := fakeOrgTokenMinter(t, "org2", 2) + created, err = doFetchOrgTokens(ctx, toks, fetchOrgs, mintToken) + require.True(t, created) + require.NoError(t, err) + assertTokenOrgs(t, toks, 1, 2) + + // fetchOrgs error + toks = fakeTokens(t, "fo1_hi", 1) + foErr := errors.New("my error") + fetchOrgs = fakeOrgFetcher(nil, foErr) + created, err = doFetchOrgTokens(ctx, toks, fetchOrgs, nil) + require.False(t, created) + require.ErrorIs(t, err, foErr) + + // partial success + toks = fakeTokens(t, "fo1_hi", 1) + fetchOrgs = fakeOrgFetcher(map[uint64]string{1: "org1", 2: "org2", 3: "org3"}, nil) + fotErr := errors.New("my error") + mintToken = fakeTokenMinter( + fakeTokenHeader(t, "", 2), + fotErr, + ) + created, err = doFetchOrgTokens(ctx, toks, fetchOrgs, mintToken) + require.True(t, created) + require.ErrorIs(t, err, fotErr) + assertTokenOrgs(t, toks, 1, 2) + + // prune tokens for orgs that user isn't member of + toks = fakeTokens(t, "fo1_hi", 1, 2) + fetchOrgs = fakeOrgFetcher(map[uint64]string{1: "org1"}, nil) + created, err = doFetchOrgTokens(ctx, toks, fetchOrgs, nil) + require.True(t, created) + require.NoError(t, err) + assertTokenOrgs(t, toks, 1) +} + +func fakeOrgFetcher(orgs map[uint64]string, err error) orgFetcher { + return func(context.Context, *fly.Client) (map[uint64]string, error) { return orgs, err } +} + +func fakeOrgTokenMinter(tb testing.TB, expectedGraphID string, oid uint64) tokenMinter { + tb.Helper() + return func(_ context.Context, _ *fly.Client, graphID string) (string, error) { + require.Equal(tb, expectedGraphID, graphID) + return fakeTokenHeader(tb, "", oid), nil + } +} + +func fakeTokenMinter(hdrsOrErrors ...any) tokenMinter { + return func(context.Context, *fly.Client, string) (string, error) { + if len(hdrsOrErrors) == 0 { + panic("unexpected call to fakeTokenMinter") + } + + hdrOrErr := hdrsOrErrors[0] + hdrsOrErrors = hdrsOrErrors[1:] + + switch hoe := hdrOrErr.(type) { + case error: + return "", hoe + case string: + return hoe, nil + default: + panic("unexpected type") + } + } +} + +var ( + permKID = []byte("hello") + permK = macaroon.NewSigningKey() + authK = macaroon.NewEncryptionKey() +) + +func fakeTokens(tb testing.TB, userToken string, oids ...uint64) *tokens.Tokens { + tb.Helper() + + return tokens.Parse(fakeTokenHeader(tb, userToken, oids...)) +} + +func fakeTokenHeader(tb testing.TB, userToken string, oids ...uint64) string { + tb.Helper() + + macs := fakeMacaroons(tb, oids...) + toks := make([][]byte, 0, len(macs)) + for _, m := range macs { + tok, err := m.Encode() + require.NoError(tb, err) + toks = append(toks, tok) + } + + hdr := macaroon.ToAuthorizationHeader(toks...) 
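+	// ToAuthorizationHeader comma-joins the encoded tokens, which is why the
+	// user token is appended with a "," below.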
+ + if userToken != "" { + if len(toks) > 0 { + hdr += "," + userToken + } else { + hdr += userToken + } + } + + return hdr +} + +func fakeMacaroons(tb testing.TB, oids ...uint64) []*macaroon.Macaroon { + tb.Helper() + + toks := make([]*macaroon.Macaroon, 0, len(oids)*2) + for _, oid := range oids { + perm := fakePermissionToken(tb, &flyio.Organization{ID: oid, Mask: resset.ActionAll}) + auth := fakeAuthToken(tb, perm) + toks = append(toks, perm, auth) + } + + return toks +} + +func fakePermissionToken(tb testing.TB, cavs ...macaroon.Caveat) *macaroon.Macaroon { + tb.Helper() + + perm, err := macaroon.New(permKID, flyio.LocationPermission, permK) + require.NoError(tb, err) + require.NoError(tb, perm.Add(cavs...)) + return perm +} + +func fakeAuthToken(tb testing.TB, perm *macaroon.Macaroon) *macaroon.Macaroon { + tb.Helper() + + require.NoError(tb, perm.Add3P(authK, flyio.LocationAuthentication)) + ticket, err := perm.ThirdPartyTicket(flyio.LocationAuthentication) + require.NoError(tb, err) + _, auth, err := macaroon.DischargeTicket(authK, flyio.LocationAuthentication, ticket) + require.NoError(tb, err) + return auth +} + +func assertTokenOrgs(tb testing.TB, toks *tokens.Tokens, expectedOIDs ...uint64) { + tb.Helper() + + actualOIDs := make([]uint64, 0, len(expectedOIDs)) + for _, mt := range toks.GetMacaroonTokens() { + mtoks, err := macaroon.Parse(mt) + require.NoError(tb, err) + require.Equal(tb, 1, len(mtoks)) + macs, _, _, _, err := macaroon.FindPermissionAndDischargeTokens(mtoks, flyio.LocationPermission) + require.NoError(tb, err) + if len(macs) != 1 { + continue + } + oid, err := flyio.OrganizationScope(&macs[0].UnsafeCaveats) + require.NoError(tb, err) + actualOIDs = append(actualOIDs, oid) + } + + slices.Sort(expectedOIDs) + slices.Sort(actualOIDs) + require.Equal(tb, expectedOIDs, actualOIDs) +}