Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

peer: simplify MsgRouter with new fn tools #9010

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ require (
github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb
github.com/lightningnetwork/lnd/cert v1.2.2
github.com/lightningnetwork/lnd/clock v1.1.1
github.com/lightningnetwork/lnd/fn v1.2.0
github.com/lightningnetwork/lnd/fn v1.2.1
github.com/lightningnetwork/lnd/healthcheck v1.2.5
github.com/lightningnetwork/lnd/kvdb v1.4.10
github.com/lightningnetwork/lnd/queue v1.1.1
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,8 @@ github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf
github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U=
github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0=
github.com/lightningnetwork/lnd/clock v1.1.1/go.mod h1:mGnAhPyjYZQJmebS7aevElXKTFDuO+uNFFfMXK1W8xQ=
github.com/lightningnetwork/lnd/fn v1.2.0 h1:YTb2m8NN5ZiJAskHeBZAmR1AiPY8SXziIYPAX1VI/ZM=
github.com/lightningnetwork/lnd/fn v1.2.0/go.mod h1:SyFohpVrARPKH3XVAJZlXdVe+IwMYc4OMAvrDY32kw0=
github.com/lightningnetwork/lnd/fn v1.2.1 h1:pPsVGrwi9QBwdLJzaEGK33wmiVKOxs/zc8H7+MamFf0=
github.com/lightningnetwork/lnd/fn v1.2.1/go.mod h1:SyFohpVrARPKH3XVAJZlXdVe+IwMYc4OMAvrDY32kw0=
github.com/lightningnetwork/lnd/healthcheck v1.2.5 h1:aTJy5xeBpcWgRtW/PGBDe+LMQEmNm/HQewlQx2jt7OA=
github.com/lightningnetwork/lnd/healthcheck v1.2.5/go.mod h1:G7Tst2tVvWo7cx6mSBEToQC5L1XOGxzZTPB29g9Rv2I=
github.com/lightningnetwork/lnd/kvdb v1.4.10 h1:vK89IVv1oVH9ubQWU+EmoCQFeVRaC8kfmOrqHbY5zoY=
Expand Down
261 changes: 261 additions & 0 deletions peer/msg_router.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
package peer

import (
"fmt"
"maps"
"sync"

"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/lnwire"
)

var (
// ErrDuplicateEndpoint is returned when an endpoint is registered with
// a name that already exists.
ErrDuplicateEndpoint = fmt.Errorf("endpoint already registered")

// ErrUnableToRouteMsg is returned when a message is unable to be
// routed to any endpoints.
ErrUnableToRouteMsg = fmt.Errorf("unable to route message")
)

// EndPointName is the name of a given endpoint. This MUST be unique across all
// registered endpoints.
type EndPointName = string

// MsgEndpoint is an interface that represents a message endpoint, or the
// sub-system that will handle processing an incoming wire message.
type MsgEndpoint interface {
// Name returns the name of this endpoint. This MUST be unique across
// all registered endpoints.
Name() EndPointName

// CanHandle returns true if the target message can be routed to this
// endpoint.
CanHandle(msg lnwire.Message) bool

// SendMessage handles the target message, and returns true if the
// message was able to be processed.
SendMessage(msg lnwire.Message) bool
}

// MsgRouter is an interface that represents a message router, which is generic
// sub-system capable of routing any incoming wire message to a set of
// registered endpoints.
//
// TODO(roasbeef): move to diff sub-system?
type MsgRouter interface {
// RegisterEndpoint registers a new endpoint with the router. If a
// duplicate endpoint exists, an error is returned.
RegisterEndpoint(MsgEndpoint) error

// UnregisterEndpoint unregisters the target endpoint from the router.
UnregisterEndpoint(EndPointName) error

// RouteMsg attempts to route the target message to a registered
// endpoint. If ANY endpoint could handle the message, then nil is
// returned. Otherwise, ErrUnableToRouteMsg is returned.
RouteMsg(lnwire.Message) error

// Start starts the peer message router.
Start()

// Stop stops the peer message router.
Stop()
}

// sendQuery sends a query to the main event loop, and returns the response.
func sendQuery[Q any, R any](sendChan chan fn.Req[Q, R], queryArg Q,
quit chan struct{}) fn.Result[R] {

query, respChan := fn.NewReq[Q, R](queryArg)

if !fn.SendOrQuit(sendChan, query, quit) {
return fn.Errf[R]("router shutting down")
}

return fn.NewResult(fn.RecvResp(respChan, nil, quit))
}

// sendQueryErr is a helper function based on sendQuery that can be used when
// the query only needs an error response.
func sendQueryErr[Q any](sendChan chan fn.Req[Q, error], queryArg Q,
quitChan chan struct{}) error {

return fn.ElimEither(
fn.Iden, fn.Iden,
sendQuery(sendChan, queryArg, quitChan).Either,
)
}

// EndpointsMap is a map of all registered endpoints.
type EndpointsMap map[EndPointName]MsgEndpoint

// MultiMsgRouter is a type of message router that is capable of routing new
// incoming messages, permitting a message to be routed to multiple registered
// endpoints.
type MultiMsgRouter struct {
startOnce sync.Once
stopOnce sync.Once

// registerChan is the channel that all new endpoints will be sent to.
registerChan chan fn.Req[MsgEndpoint, error]

// unregisterChan is the channel that all endpoints that are to be
// removed are sent to.
unregisterChan chan fn.Req[EndPointName, error]

// msgChan is the channel that all messages will be sent to for
// processing.
msgChan chan fn.Req[lnwire.Message, error]

// endpointsQueries is a channel that all queries to the endpoints map
// will be sent to.
endpointQueries chan fn.Req[MsgEndpoint, EndpointsMap]

wg sync.WaitGroup
quit chan struct{}
}

// NewMultiMsgRouter creates a new instance of a peer message router.
func NewMultiMsgRouter() *MultiMsgRouter {
return &MultiMsgRouter{
registerChan: make(chan fn.Req[MsgEndpoint, error]),
unregisterChan: make(chan fn.Req[EndPointName, error]),
msgChan: make(chan fn.Req[lnwire.Message, error]),
endpointQueries: make(chan fn.Req[MsgEndpoint, EndpointsMap]),
quit: make(chan struct{}),
}
}

// Start starts the peer message router.
func (p *MultiMsgRouter) Start() {
peerLog.Infof("Starting MsgRouter")

p.startOnce.Do(func() {
p.wg.Add(1)
go p.msgRouter()
})
}

// Stop stops the peer message router.
func (p *MultiMsgRouter) Stop() {
peerLog.Infof("Stopping MsgRouter")

p.stopOnce.Do(func() {
close(p.quit)
p.wg.Wait()
})
}

// RegisterEndpoint registers a new endpoint with the router. If a duplicate
// endpoint exists, an error is returned.
func (p *MultiMsgRouter) RegisterEndpoint(endpoint MsgEndpoint) error {
return sendQueryErr(p.registerChan, endpoint, p.quit)
}

// UnregisterEndpoint unregisters the target endpoint from the router.
func (p *MultiMsgRouter) UnregisterEndpoint(name EndPointName) error {
return sendQueryErr(p.unregisterChan, name, p.quit)
}

// RouteMsg attempts to route the target message to a registered endpoint. If
// ANY endpoint could handle the message, then nil is returned.
func (p *MultiMsgRouter) RouteMsg(msg lnwire.Message) error {
return sendQueryErr(p.msgChan, msg, p.quit)
}

// Endpoints returns a list of all registered endpoints.
func (p *MultiMsgRouter) endpoints() fn.Result[EndpointsMap] {
return sendQuery(p.endpointQueries, nil, p.quit)
}

// msgRouter is the main goroutine that handles all incoming messages.
func (p *MultiMsgRouter) msgRouter() {
defer p.wg.Done()

// endpoints is a map of all registered endpoints.
endpoints := make(map[EndPointName]MsgEndpoint)

for {
select {
// A new endpoint was just sent in, so we'll add it to our set
// of registered endpoints.
case newEndpointMsg := <-p.registerChan:
endpoint := newEndpointMsg.Request

peerLog.Infof("MsgRouter: registering new "+
"MsgEndpoint(%s)", endpoint.Name())

// If this endpoint already exists, then we'll return
// an error as we require unique names.
if _, ok := endpoints[endpoint.Name()]; ok {
peerLog.Errorf("MsgRouter: rejecting "+
"duplicate endpoint: %v",
endpoint.Name())

newEndpointMsg.Resolve(ErrDuplicateEndpoint)

continue
}

endpoints[endpoint.Name()] = endpoint

newEndpointMsg.Resolve(nil)

// A request to unregister an endpoint was just sent in, so
// we'll attempt to remove it.
case endpointName := <-p.unregisterChan:
delete(endpoints, endpointName.Request)

peerLog.Infof("MsgRouter: unregistering "+
"MsgEndpoint(%s)", endpointName.Request)

endpointName.Resolve(nil)

// A new message was just sent in. We'll attempt to route it to
// all the endpoints that can handle it.
case msgQuery := <-p.msgChan:
msg := msgQuery.Request

// Loop through all the endpoints and send the message
// to those that can handle it the message.
var couldSend bool
for _, endpoint := range endpoints {
if endpoint.CanHandle(msg) {
peerLog.Tracef("MsgRouter: sending "+
"msg %T to endpoint %s", msg,
endpoint.Name())

sent := endpoint.SendMessage(msg)
couldSend = couldSend || sent
}
}

var err error
if !couldSend {
peerLog.Tracef("MsgRouter: unable to route "+
"msg %T", msg)

err = ErrUnableToRouteMsg
}

msgQuery.Resolve(err)

// A query for the endpoint state just came in, we'll send back
// a copy of our current state.
case endpointQuery := <-p.endpointQueries:
endpointsCopy := make(EndpointsMap, len(endpoints))
maps.Copy(endpointsCopy, endpoints)

endpointQuery.Resolve(endpointsCopy)

case <-p.quit:
return
}
}
}

// A compile time check to ensure MultiMsgRouter implements the MsgRouter
// interface.
var _ MsgRouter = (*MultiMsgRouter)(nil)
Loading
Loading