Skip to content

Commit

Permalink
feat: extend daemon to work as a stateful profile store
Browse files Browse the repository at this point in the history
  • Loading branch information
tinyzimmer committed Nov 9, 2023
1 parent 9c68a39 commit 03a5941
Show file tree
Hide file tree
Showing 6 changed files with 433 additions and 168 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ require (
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.2.1-beta.2
github.com/webmeshproj/api v0.11.2
github.com/webmeshproj/api v0.11.4-0.20231109201546-44726ffeea69
golang.org/x/crypto v0.14.0
golang.org/x/exp v0.0.0-20231006140011-7918f672742d
golang.org/x/net v0.17.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1574,6 +1574,8 @@ github.com/warpfork/go-wish v0.0.0-20220906213052-39a1cc7a02d0 h1:GDDkbFiaK8jsSD
github.com/warpfork/go-wish v0.0.0-20220906213052-39a1cc7a02d0/go.mod h1:x6AKhvSSexNrVSrViXSHUEbICjmGXhtgABaHIySUSGw=
github.com/webmeshproj/api v0.11.2 h1:121MlNlwWNU2II3gF2v9I/x342s/yBKvDs+JqmjKDfY=
github.com/webmeshproj/api v0.11.2/go.mod h1:xuYk93HM4aZWWlTh96Z2nIg1YhqcRG36nOfcifzHeM4=
github.com/webmeshproj/api v0.11.4-0.20231109201546-44726ffeea69 h1:lnpiABZ5U10GCRacgmvUK9tCuBLTJqQg0vWmm4r8N6c=
github.com/webmeshproj/api v0.11.4-0.20231109201546-44726ffeea69/go.mod h1:xuYk93HM4aZWWlTh96Z2nIg1YhqcRG36nOfcifzHeM4=
github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1 h1:EKhdznlJHPMoKr0XTrX+IlJs1LH3lyx2nfr1dOlZ79k=
github.com/whyrusleeping/go-keyspace v0.0.0-20160322163242-5b898ac5add1/go.mod h1:8UvriyWtv5Q5EOgjHaSseUEdkQfvwFv1I/In/O2M9gc=
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
Expand Down
148 changes: 69 additions & 79 deletions pkg/cmd/daemoncmd/connmgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
v1 "github.com/webmeshproj/api/go/v1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"

"github.com/webmeshproj/webmesh/pkg/config"
"github.com/webmeshproj/webmesh/pkg/context"
Expand All @@ -38,6 +37,7 @@ import (
"github.com/webmeshproj/webmesh/pkg/meshnet/endpoints"
"github.com/webmeshproj/webmesh/pkg/meshnet/system/firewall"
"github.com/webmeshproj/webmesh/pkg/storage"
"github.com/webmeshproj/webmesh/pkg/storage/errors"
"github.com/webmeshproj/webmesh/pkg/storage/types"
)

Expand All @@ -54,41 +54,29 @@ var (
ErrConnected = status.Errorf(codes.FailedPrecondition, "connected to the specified network")
)

// ParamsDir is the relative path where connection request parameters are stored
// for each connection ID.
const ParamsDir = "params"

// ParamsFile returns the relative path to the parameters file for the given connection ID.
func ParamsFile(connID string) string {
return filepath.Join(ParamsDir, fmt.Sprintf("%s.json", connID))
}

// ConnManager manages the connections for the daemon.
type ConnManager struct {
nodeID types.NodeID
key crypto.PrivateKey
conf Config
conns map[string]embed.Node
ports map[uint16]string
utuns map[uint16]string
log *slog.Logger
mu sync.RWMutex
nodeID types.NodeID
key crypto.PrivateKey
conf Config
profiles ProfileStore
conns map[string]embed.Node
ports map[uint16]string
utuns map[uint16]string
log *slog.Logger
mu sync.RWMutex
}

// NewConnManager creates a new connection manager.
func NewConnManager(conf Config) (*ConnManager, error) {
log := conf.NewLogger().With("appdaemon", "connmgr")
key, err := conf.LoadKey(log)
if err != nil {
return nil, fmt.Errorf("failed to load key: %w", err)
return nil, fmt.Errorf("load key: %w", err)
}
if conf.Persistence.Path != "" {
err := os.MkdirAll(filepath.Join(conf.Persistence.Path, ParamsDir), 0700)
if err != nil {
return nil, fmt.Errorf("failed to create params directory: %w", err)
}
// TODO: Store connection status in this directory and restart connections
// that were running when the daemon was stopped.
profiles, err := NewProfileStore(conf.Persistence.Path)
if err != nil {
return nil, fmt.Errorf("create profile store: %w", err)
}
var nodeID types.NodeID
if conf.NodeID != "" {
Expand All @@ -97,13 +85,14 @@ func NewConnManager(conf Config) (*ConnManager, error) {
nodeID = types.NodeID(key.ID())
}
return &ConnManager{
nodeID: nodeID,
key: key,
conf: conf,
conns: make(map[string]embed.Node),
ports: make(map[uint16]string),
utuns: make(map[uint16]string),
log: log,
nodeID: nodeID,
key: key,
conf: conf,
profiles: profiles,
conns: make(map[string]embed.Node),
ports: make(map[uint16]string),
utuns: make(map[uint16]string),
log: log,
}, nil
}

Expand All @@ -118,23 +107,28 @@ func (m *ConnManager) PublicKey() string {
return encoded
}

// Close closes the connection manager and all connections.
// Profiles returns the profiles store.
func (m *ConnManager) Profiles() ProfileStore {
return m.profiles
}

// Close closes the connection manager and all connections. It is not
// safe to use the connection manager after calling Close.
func (m *ConnManager) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
defer m.profiles.Close()
for id, conn := range m.conns {
m.log.Info("Stopping connection", "id", id)
err := conn.Stop(context.WithLogger(context.Background(), m.log))
if err != nil {
m.log.Error("Failed to stop connection", "error", err.Error())
}
}
m.conns = nil
m.ports = nil
return nil
}

// ConnIDs returns the IDs of the connections.
// ConnIDs returns the IDs of all currently active connections.
func (m *ConnManager) ConnIDs() []string {
m.mu.RLock()
defer m.mu.RUnlock()
Expand All @@ -149,29 +143,34 @@ func (m *ConnManager) Get(connID string) (embed.Node, bool) {
return n, ok
}

// DataDir returns the data directory for the given connection ID.
func (m *ConnManager) DataDir(connID string) string {
return filepath.Join(m.conf.Persistence.Path, connID)
// GetStatus returns the status of the connection for the given ID.
func (m *ConnManager) GetStatus(connID string) v1.DaemonConnStatus {
m.mu.RLock()
defer m.mu.RUnlock()
c, ok := m.conns[connID]
if !ok {
return v1.DaemonConnStatus_DISCONNECTED
}
if c.MeshNode().Started() {
return v1.DaemonConnStatus_CONNECTED
}
return v1.DaemonConnStatus_CONNECTING
}

// StoredConns returns all connection IDs known to the persistence layer.
func (m *ConnManager) StoredConns() ([]string, error) {
if m.conf.Persistence.Path == "" {
return nil, nil
}
// GetMeshNode returns the full mesh node for the given ID.
func (m *ConnManager) GetMeshNode(ctx context.Context, connID string) (types.MeshNode, error) {
m.mu.RLock()
defer m.mu.RUnlock()
contents, err := os.ReadDir(m.conf.Persistence.Path)
if err != nil {
return nil, fmt.Errorf("read dir: %w", err)
}
var ids []string
for _, entry := range contents {
if entry.IsDir() {
ids = append(ids, entry.Name())
}
conn, ok := m.conns[connID]
if !ok {
return types.MeshNode{}, ErrNotConnected
}
return ids, nil
return conn.Storage().MeshDB().Peers().Get(ctx, conn.MeshNode().ID())
}

// DataDir returns the data directory for the given connection ID.
func (m *ConnManager) DataDir(connID string) string {
return filepath.Join(m.conf.Persistence.Path, connID)
}

// DropStorage drops storage for the connection with the given ID.
Expand All @@ -189,10 +188,6 @@ func (m *ConnManager) DropStorage(ctx context.Context, connID string) error {
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("remove all: %w", err)
}
err = os.Remove(filepath.Join(m.conf.Persistence.Path, ParamsFile(connID)))
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("remove params: %w", err)
}
return nil
}

Expand Down Expand Up @@ -223,20 +218,15 @@ func (m *ConnManager) NewConn(ctx context.Context, req *v1.ConnectRequest) (id s
if err != nil {
return "", nil, err
}
if m.conf.Persistence.Path != "" {
m.log.Info("Saving connection parameters in case of restart", "id", connID)
paramsJSON, err := protojson.Marshal(req)
if err != nil {
return "", nil, status.Errorf(codes.Internal, "failed to marshal connection parameters: %v", err)
}
paramsPath := filepath.Join(m.conf.Persistence.Path, ParamsFile(connID))
err = os.WriteFile(paramsPath, paramsJSON, 0600)
if err != nil {
return "", nil, status.Errorf(codes.Internal, "failed to write connection parameters: %v", err)
m.log.Info("Creating new webmesh node", "id", connID, "port", port)
profile, err := m.profiles.Get(ctx, ProfileID(connID))
if err != nil {
if errors.IsNotFound(err) {
return "", nil, status.Errorf(codes.NotFound, "profile not found")
}
return "", nil, status.Errorf(codes.Internal, "failed to get profile: %v", err)
}
m.log.Info("Creating new webmesh node", "id", connID, "port", port)
cfg, err := m.buildConnConfig(ctx, req, connID, port)
cfg, err := m.buildConnConfig(ctx, profile.ConnectionParameters, connID, port)
if err != nil {
return "", nil, err
}
Expand Down Expand Up @@ -343,7 +333,7 @@ func (m *ConnManager) assignUTUNIndex(connID string) (uint16, error) {
}
}

func (m *ConnManager) buildConnConfig(ctx context.Context, req *v1.ConnectRequest, connID string, listenPort uint16) (*config.Config, error) {
func (m *ConnManager) buildConnConfig(ctx context.Context, req *v1.ConnectionParameters, connID string, listenPort uint16) (*config.Config, error) {
conf := config.NewDefaultConfig(m.nodeID.String())
conf.Global.LogLevel = m.conf.LogLevel
conf.Global.LogFormat = m.conf.LogFormat
Expand Down Expand Up @@ -435,22 +425,22 @@ func (m *ConnManager) buildConnConfig(ctx context.Context, req *v1.ConnectReques
conf.TLS.InsecureSkipVerify = req.GetTls().GetSkipVerify()
}
switch req.GetAddrType() {
case v1.ConnectRequest_ADDR:
case v1.ConnectionParameters_ADDR:
conf.Mesh.JoinAddresses = req.GetAddrs()
case v1.ConnectRequest_RENDEZVOUS:
case v1.ConnectionParameters_RENDEZVOUS:
conf.Discovery.Discover = true
conf.Discovery.Rendezvous = req.GetAddrs()[0]
case v1.ConnectRequest_MULTIADDR:
case v1.ConnectionParameters_MULTIADDR:
conf.Mesh.JoinMultiaddrs = req.GetAddrs()
}
switch req.GetAuthMethod() {
case v1.NetworkAuthMethod_NO_AUTH:
case v1.NetworkAuthMethod_BASIC:
conf.Auth.Basic.Username = req.GetAuthCredentials()[v1.ConnectRequest_BASIC_USERNAME.String()]
conf.Auth.Basic.Password = req.GetAuthCredentials()[v1.ConnectRequest_BASIC_PASSWORD.String()]
conf.Auth.Basic.Username = req.GetAuthCredentials()[v1.ConnectionParameters_BASIC_USERNAME.String()]
conf.Auth.Basic.Password = req.GetAuthCredentials()[v1.ConnectionParameters_BASIC_PASSWORD.String()]
case v1.NetworkAuthMethod_LDAP:
conf.Auth.LDAP.Username = req.GetAuthCredentials()[v1.ConnectRequest_LDAP_USERNAME.String()]
conf.Auth.LDAP.Password = req.GetAuthCredentials()[v1.ConnectRequest_LDAP_PASSWORD.String()]
conf.Auth.LDAP.Username = req.GetAuthCredentials()[v1.ConnectionParameters_LDAP_USERNAME.String()]
conf.Auth.LDAP.Password = req.GetAuthCredentials()[v1.ConnectionParameters_LDAP_PASSWORD.String()]
case v1.NetworkAuthMethod_MTLS:
conf.Auth.MTLS.CertData = req.GetTls().GetCertData()
conf.Auth.MTLS.KeyData = req.GetTls().GetKeyData()
Expand Down

0 comments on commit 03a5941

Please sign in to comment.