Skip to content

Commit

Permalink
RPC for sending tokens to agent
Browse files Browse the repository at this point in the history
  • Loading branch information
btoews committed Apr 11, 2024
1 parent 0549b55 commit 1073034
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 44 deletions.
48 changes: 44 additions & 4 deletions agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/superfly/flyctl/agent/internal/proto"
"github.com/superfly/flyctl/gql"
"github.com/superfly/flyctl/internal/buildinfo"
"github.com/superfly/flyctl/internal/config"
"github.com/superfly/flyctl/internal/flag"
"github.com/superfly/flyctl/internal/logger"
"github.com/superfly/flyctl/internal/sentry"
Expand Down Expand Up @@ -127,14 +128,53 @@ const (
)

type Client struct {
network string
address string
dialer net.Dialer
network string
address string
dialer net.Dialer
agentRefusedTokens bool
}

var errDone = errors.New("done")

func (c *Client) do(parent context.Context, fn func(net.Conn) error) (err error) {
func (c *Client) do(ctx context.Context, fn func(net.Conn) error) (err error) {
if c.agentRefusedTokens {
return c.doNoTokens(ctx, fn)
}

toks := config.Tokens(ctx)
if toks.All() == "" {
return c.doNoTokens(ctx, fn)
}

var tokArgs []string
if fcf := toks.FromConfigFile; fcf != "" {
tokArgs = append(tokArgs, "cfg", fcf)
} else {
tokArgs = append(tokArgs, "str", toks.All())
}

return c.doNoTokens(ctx, func(conn net.Conn) error {
if err := proto.Write(conn, "set-token", tokArgs...); err != nil {
return err
}

data, err := proto.Read(conn)

switch {
case err == nil && string(data) == "ok":
return fn(conn)
case err != nil:
return err
case isError(data):
c.agentRefusedTokens = true
return c.do(ctx, fn)
default:
return err
}
})
}

func (c *Client) doNoTokens(parent context.Context, fn func(net.Conn) error) (err error) {
var conn net.Conn
if conn, err = c.dialContext(parent); err != nil {
return err
Expand Down
66 changes: 55 additions & 11 deletions agent/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ import (

"github.com/azazeal/pause"
fly "github.com/superfly/fly-go"
"github.com/superfly/fly-go/tokens"
"github.com/superfly/flyctl/agent"
"github.com/superfly/flyctl/internal/config"
"github.com/superfly/flyctl/internal/env"
"github.com/superfly/flyctl/internal/flyutil"
"github.com/superfly/flyctl/internal/sentry"
"github.com/superfly/flyctl/internal/wireguard"
"github.com/superfly/flyctl/wg"
Expand All @@ -25,7 +28,6 @@ import (
type Options struct {
Socket string
Logger *log.Logger
Client *fly.Client
Background bool
ConfigFile string
ConfigWebsockets bool
Expand All @@ -49,11 +51,19 @@ func Run(ctx context.Context, opt Options) (err error) {
return
}

toks := config.Tokens(ctx)

monitorCtx, cancelMonitor := context.WithCancel(ctx)
config.MonitorTokens(monitorCtx, toks, nil)

err = (&server{
Options: opt,
listener: l,
currentChange: latestChangeAt,
tunnels: make(map[string]*wg.Tunnel),
Options: opt,
listener: l,
runCtx: ctx,
currentChange: latestChangeAt,
tunnels: make(map[string]*wg.Tunnel),
tokens: toks,
cancelTokenMonitoring: cancelMonitor,
}).serve(ctx, l)

return
Expand Down Expand Up @@ -105,9 +115,12 @@ type server struct {

listener net.Listener

mu sync.Mutex
currentChange time.Time
tunnels map[string]*wg.Tunnel
runCtx context.Context
mu sync.Mutex
currentChange time.Time
tunnels map[string]*wg.Tunnel
tokens *tokens.Tokens
cancelTokenMonitoring func()
}

type terminateError struct{ error }
Expand Down Expand Up @@ -206,7 +219,7 @@ func (s *server) checkForConfigChange() (err error) {
return
}

func (s *server) buildTunnel(ctx context.Context, org *fly.Organization, recycle bool) (tunnel *wg.Tunnel, err error) {
func (s *server) buildTunnel(ctx context.Context, org *fly.Organization, recycle bool, client *fly.Client) (tunnel *wg.Tunnel, err error) {
s.mu.Lock()
defer s.mu.Unlock()

Expand All @@ -217,7 +230,7 @@ func (s *server) buildTunnel(ctx context.Context, org *fly.Organization, recycle
}

var state *wg.WireGuardState
if state, err = wireguard.StateForOrg(ctx, s.Client, org, os.Getenv("FLY_AGENT_WG_REGION"), "", recycle); err != nil {
if state, err = wireguard.StateForOrg(ctx, client, org, os.Getenv("FLY_AGENT_WG_REGION"), "", recycle); err != nil {
return
}

Expand Down Expand Up @@ -349,7 +362,7 @@ func (s *server) clean(ctx context.Context) {
break
}

if err := wireguard.PruneInvalidPeers(ctx, s.Client); err != nil {
if err := wireguard.PruneInvalidPeers(ctx, s.GetClient(ctx)); err != nil {
s.printf("failed pruning invalid peers: %v", err)
}

Expand All @@ -361,6 +374,37 @@ func (s *server) clean(ctx context.Context) {
}
}

// GetClient returns an API client that uses the server's tokens. Sessions may
// have their own tokens, so should use session.getClient instead.
func (s *server) GetClient(ctx context.Context) *fly.Client {
s.mu.Lock()
defer s.mu.Unlock()

return flyutil.NewClientFromOptions(ctx, fly.ClientOptions{Tokens: s.tokens})
}

// UpdateTokensFromClient replaces the server's tokens with those from the
// client if the new ones seem better. Specifically, if the agent was started
// with `FLY_API_TOKEN`, but a later client is using tokens form a config file.
func (s *server) UpdateTokensFromClient(t *tokens.Tokens) {
s.mu.Lock()
defer s.mu.Unlock()

if s.tokens.FromConfigFile != "" || t.FromConfigFile == "" {
return
}

s.print("received new tokens from client")

s.cancelTokenMonitoring()

monitorCtx, cancelMonitor := context.WithCancel(s.runCtx)
config.MonitorTokens(monitorCtx, t, nil)

s.tokens = t
s.cancelTokenMonitoring = cancelMonitor
}

func (s *server) print(v ...interface{}) {
s.Logger.Print(v...)
}
Expand Down
106 changes: 82 additions & 24 deletions agent/server/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ import (
"golang.org/x/sync/errgroup"

fly "github.com/superfly/fly-go"
"github.com/superfly/fly-go/tokens"
"github.com/superfly/flyctl/agent"
"github.com/superfly/flyctl/agent/internal/proto"
"github.com/superfly/flyctl/wg"

"github.com/superfly/flyctl/internal/buildinfo"
"github.com/superfly/flyctl/internal/config"
"github.com/superfly/flyctl/internal/flyutil"
)

type id uint64
Expand All @@ -38,6 +41,7 @@ type session struct {
conn net.Conn
logger *log.Logger
id id
tokens *tokens.Tokens
}

var errUnsupportedCommand = errors.New("unsupported command")
Expand Down Expand Up @@ -81,6 +85,10 @@ func runSession(ctx context.Context, srv *server, conn net.Conn, id id) {
return
}

s.runCommand(ctx)
}

func (s *session) runCommand(ctx context.Context) {
buf, err := proto.Read(s.conn)
if len(buf) > 0 {
s.logger.Printf("<- (% 5d) %q", len(buf), redact(buf))
Expand All @@ -95,30 +103,37 @@ func runSession(ctx context.Context, srv *server, conn net.Conn, id id) {
}

args := strings.Split(string(buf), " ")

fn := handlers[args[0]]
if fn == nil {
var handler func(*session, context.Context, ...string)

switch args[0] {
case "kill":
handler = (*session).kill
case "ping":
handler = (*session).ping
case "establish":
handler = (*session).establish
case "reestablish":
handler = (*session).reestablish
case "connect":
handler = (*session).connect
case "probe":
handler = (*session).probe
case "instances":
handler = (*session).instances
case "resolve":
handler = (*session).resolve
case "lookupTxt":
handler = (*session).lookupTxt
case "ping6":
handler = (*session).ping6
case "set-token":
handler = (*session).setToken
default:
s.error(errUnsupportedCommand)

return
}

fn(s, ctx, args[1:]...)
}

type handlerFunc func(*session, context.Context, ...string)

var handlers = map[string]handlerFunc{
"kill": (*session).kill,
"ping": (*session).ping,
"establish": (*session).establish,
"reestablish": (*session).reestablish,
"connect": (*session).connect,
"probe": (*session).probe,
"instances": (*session).instances,
"resolve": (*session).resolve,
"lookupTxt": (*session).lookupTxt,
"ping6": (*session).ping6,
handler(s, ctx, args[1:]...)
}

var errMalformedKill = errors.New("malformed kill command")
Expand Down Expand Up @@ -161,7 +176,7 @@ func (s *session) doEstablish(ctx context.Context, recycle bool, args ...string)
return
}

tunnel, err := s.srv.buildTunnel(ctx, org, recycle)
tunnel, err := s.srv.buildTunnel(ctx, org, recycle, s.getClient(ctx))
if err != nil {
s.error(err)

Expand All @@ -185,7 +200,7 @@ func (s *session) reestablish(ctx context.Context, args ...string) {
var errNoSuchOrg = errors.New("no such organization")

func (s *session) fetchOrg(ctx context.Context, slug string) (*fly.Organization, error) {
orgs, err := s.srv.Client.GetOrganizations(ctx)
orgs, err := s.getClient(ctx).GetOrganizations(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -533,6 +548,43 @@ func (s *session) ping6(ctx context.Context, args ...string) {
}
}

var errMalformedSetToken = errors.New("malformed set-token command")

// setToken instructs the agent which tokens to use for API calls.
func (s *session) setToken(ctx context.Context, args ...string) {
s.exactArgs(2, args, errMalformedSetToken)

switch args[0] {
case "cfg":
tokStr, err := config.ReadAccessToken(args[1])
if err != nil {
s.error(err)
return
}

s.tokens = tokens.Parse(tokStr)
s.tokens.FromConfigFile = args[1]
case "str":
s.tokens = tokens.Parse(args[1])
}

go s.srv.UpdateTokensFromClient(s.tokens)

s.ok()

s.runCommand(ctx)
}

// getClient returns an API client that uses any API tokens sent by the client.
// If not have been sent, it falls back to using the server's tokens.
func (s *session) getClient(ctx context.Context) *fly.Client {
if s.tokens == nil {
return s.srv.GetClient(ctx)
}

return flyutil.NewClientFromOptions(ctx, fly.ClientOptions{Tokens: s.tokens})
}

func (s *session) error(err error) bool {
return s.reply("err", err.Error())
}
Expand Down Expand Up @@ -606,8 +658,14 @@ func isClosed(err error) bool {
return errors.Is(err, net.ErrClosed)
}

var redactRx = regexp.MustCompile(`(PrivateKey|private)":".*?"`)
var (
redactPrivateKeyRx = regexp.MustCompile(`(PrivateKey|private)":".*?"`)
redactTokenRx = regexp.MustCompile(`(fo1_|fm1[ar]_|fm2_)[a-zA-Z0-9/+_-]+=*`)
)

func redact(buf []byte) []byte {
return redactRx.ReplaceAll(buf, []byte(`PrivateKey":"[redacted]"`))
buf = redactPrivateKeyRx.ReplaceAll(buf, []byte(`PrivateKey":"[redacted]"`))
buf = redactTokenRx.ReplaceAll(buf, []byte(`$1[redacted]`))
return buf

}
6 changes: 1 addition & 5 deletions internal/command/agent/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,11 @@ func run(ctx context.Context) error {
}
defer closeLogger()

apiClient := fly.ClientFromContext(ctx)
if !apiClient.Authenticated() {
if config.Tokens(ctx).GraphQL() == "" {
logger.Println(fly.ErrNoAuthToken)
return fly.ErrNoAuthToken
}

config.MonitorTokens(ctx, config.Tokens(ctx), nil)

unlock, err := lock(ctx, logger)
if err != nil {
return err
Expand All @@ -67,7 +64,6 @@ func run(ctx context.Context) error {
opt := server.Options{
Socket: socketPath(ctx),
Logger: logger,
Client: apiClient,
Background: logPath != "",
ConfigFile: state.ConfigFile(ctx),
ConfigWebsockets: viper.GetBool(flyctl.ConfigWireGuardWebsockets),
Expand Down

0 comments on commit 1073034

Please sign in to comment.